mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 21:56:25 -04:00
Compare commits
113 Commits
debug
...
account-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e24916dac | ||
|
|
875b8d662c | ||
|
|
41b4e3177a | ||
|
|
3186876d5e | ||
|
|
13eae9bc93 | ||
|
|
de99624610 | ||
|
|
accada3311 | ||
|
|
71af7edd05 | ||
|
|
e17d8127e3 | ||
|
|
ea51ce876e | ||
|
|
2115e2c3f0 | ||
|
|
7a6ca3ee37 | ||
|
|
70b4628b5a | ||
|
|
f42c775e45 | ||
|
|
24970a1746 | ||
|
|
de3e67e7ae | ||
|
|
7be83a0199 | ||
|
|
7d0331f41e | ||
|
|
7af55fbd71 | ||
|
|
7fa1bbc722 | ||
|
|
66d8bbf8e2 | ||
|
|
6ea98f0ce7 | ||
|
|
6a456c52bf | ||
|
|
4d00207c3b | ||
|
|
2de0777f7a | ||
|
|
0ee56e14d9 | ||
|
|
20fc8e879e | ||
|
|
48edfa601f | ||
|
|
a2a49bdd47 | ||
|
|
a2fb274b86 | ||
|
|
a61e9da3e9 | ||
|
|
f6f7260897 | ||
|
|
c557c98390 | ||
|
|
7d849a92c0 | ||
|
|
f5e7449d01 | ||
|
|
8420a52563 | ||
|
|
6315644065 | ||
|
|
ef55b9eccc | ||
|
|
218345e0ff | ||
|
|
4b943c34b7 | ||
|
|
560190519d | ||
|
|
9bc8e6e29e | ||
|
|
9872bee41d | ||
|
|
3a915decd7 | ||
|
|
50e6389a1d | ||
|
|
bbaee18cd5 | ||
|
|
32d1b2d602 | ||
|
|
2a59f04540 | ||
|
|
446de5e2bc | ||
|
|
147971fdfe | ||
|
|
ed259a6a03 | ||
|
|
a3abc211b3 | ||
|
|
20a5afc359 | ||
|
|
00023bf110 | ||
|
|
2806d73161 | ||
|
|
2d7f08c609 | ||
|
|
0c0fd380bd | ||
|
|
ffce48ca5f | ||
|
|
d23b5c892b | ||
|
|
113c21b0e1 | ||
|
|
ab00c41dad | ||
|
|
664d1388aa | ||
|
|
010a8bfdc1 | ||
|
|
6cb697eed6 | ||
|
|
e0bed2b0fb | ||
|
|
601d429d82 | ||
|
|
30f025e7dd | ||
|
|
b4d7605147 | ||
|
|
d54b6967ce | ||
|
|
174e07fefd | ||
|
|
871500c5cc | ||
|
|
cc04aef7b4 | ||
|
|
3ed8b9cee9 | ||
|
|
08b6e9d647 | ||
|
|
bdeb95c58c | ||
|
|
6dc185e141 | ||
|
|
7100be83cd | ||
|
|
67ce14eaea | ||
|
|
d58cf50127 | ||
|
|
40af1a50e3 | ||
|
|
ac05f69131 | ||
|
|
8126d95316 | ||
|
|
0a70e4c5d4 | ||
|
|
106fc75936 | ||
|
|
669904cd06 | ||
|
|
f8b5eedd38 | ||
|
|
931521d505 | ||
|
|
1a5f3c653c | ||
|
|
78044c226d | ||
|
|
389c9619af | ||
|
|
4be826450b | ||
|
|
738387f2de | ||
|
|
baf0678ceb | ||
|
|
7fef8f6758 | ||
|
|
6829a64a2d | ||
|
|
cbf500024f | ||
|
|
509e184e10 | ||
|
|
3e88b7c56e | ||
|
|
b952d8693d | ||
|
|
5b46cc8e9c | ||
|
|
a9d06b883f | ||
|
|
5f06b202c3 | ||
|
|
0eb99c266a | ||
|
|
bac95ace18 | ||
|
|
9812de853b | ||
|
|
ad4f0a6fdf | ||
|
|
4c758c6e52 | ||
|
|
ec5095ba6b | ||
|
|
49a54624f8 | ||
|
|
729bcf2b01 | ||
|
|
a0cdb58303 | ||
|
|
39c99781cb | ||
|
|
01f24907c5 |
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [netbirdio]
|
||||
1
.github/workflows/golang-test-linux.yml
vendored
1
.github/workflows/golang-test-linux.yml
vendored
@@ -13,6 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
|
||||
@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
|
||||
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
||||
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
||||
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
||||
|
||||
"128.0.0.0", "8000::", // 2nd split subnet for default routes
|
||||
}
|
||||
|
||||
if slices.Contains(wellKnown, addr.String()) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -13,10 +14,11 @@ import (
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
serverInstanceMu sync.Mutex
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
|
||||
@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||
|
||||
p.serverInstanceMu.Lock()
|
||||
p.serverInstance = serverInstance
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
in := new(proto.DownRequest)
|
||||
_, err := p.serverInstance.Down(p.ctx, in)
|
||||
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
|
||||
log.Errorf("failed to stop daemon: %v", err)
|
||||
}
|
||||
}
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
p.cancel()
|
||||
|
||||
|
||||
@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
|
||||
@@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
|
||||
m.optionalEntries["PREROUTING"] = []entry{
|
||||
{
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
||||
position: 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -18,22 +18,24 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpPre = "jump-pre"
|
||||
jumpNat = "jump-nat"
|
||||
matchSet = "--match-set"
|
||||
)
|
||||
|
||||
@@ -323,24 +325,25 @@ func (r *router) Reset() error {
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
err := r.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
if err := r.cleanJumpRules(); err != nil {
|
||||
return fmt.Errorf("clean jump rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTPRE, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s, error: %v", chain, err)
|
||||
return err
|
||||
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
} else if ok {
|
||||
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
|
||||
return err
|
||||
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -349,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
if err := r.createAndSetupChain(chain); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chain, err)
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
} {
|
||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,6 +369,10 @@ func (r *router) createContainers() error {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add static nat rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
@@ -366,6 +380,32 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
"!", "-o", "lo",
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||
return fmt.Errorf("add outbound masquerade rule: %v", err)
|
||||
}
|
||||
r.rules["static-nat-outbound"] = rule1
|
||||
|
||||
// Second rule for return traffic masquerade
|
||||
rule2 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
"-o", r.wgIface.Name(),
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||
return fmt.Errorf("add return masquerade rule: %v", err)
|
||||
}
|
||||
r.rules["static-nat-return"] = rule2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
@@ -377,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) getTableForChain(chain string) string {
|
||||
if chain == chainRTNAT {
|
||||
switch chain {
|
||||
case chainRTNAT:
|
||||
return tableNat
|
||||
case chainRTPRE:
|
||||
return tableMangle
|
||||
default:
|
||||
return tableFilter
|
||||
}
|
||||
return tableFilter
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
@@ -398,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTNAT}
|
||||
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
// Jump to NAT chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat jump rule: %v", err)
|
||||
}
|
||||
r.rules[ipv4Nat] = rule
|
||||
r.rules[jumpNat] = natRule
|
||||
|
||||
// Jump to prerouting chain
|
||||
preRule := []string{"-j", chainRTPRE}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpPre] = preRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
rule, found := r.rules[ipv4Nat]
|
||||
if found {
|
||||
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
|
||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
table := tableNat
|
||||
chain := chainPOSTROUTING
|
||||
if ruleKey == jumpPre {
|
||||
table = tableMangle
|
||||
chain = chainPREROUTING
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -424,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
rule := []string{"-i", r.wgIface.Name()}
|
||||
if pair.Inverse {
|
||||
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||
}
|
||||
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", pair.Source.String(),
|
||||
"-d", pair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -444,13 +518,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nat rule %s not found", ruleKey)
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -482,16 +555,6 @@ func (r *router) updateState() {
|
||||
}
|
||||
}
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
lointdir := "-o"
|
||||
if inverse {
|
||||
intdir = "-o"
|
||||
lointdir = "-i"
|
||||
}
|
||||
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
var rule []string
|
||||
|
||||
|
||||
@@ -3,17 +3,18 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
@@ -34,14 +35,24 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||
// Now 5 rules:
|
||||
// 1. established rule in forward chain
|
||||
// 2. jump rule to NAT chain
|
||||
// 3. jump rule to PRE chain
|
||||
// 4. static outbound masquerade rule
|
||||
// 5. static return masquerade rule
|
||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
require.True(t, exists, "postrouting jump rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||
require.True(t, exists, "prerouting jump rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
@@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "adding NAT rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
}
|
||||
|
||||
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
@@ -79,52 +83,66 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset iptables manager: %s", err)
|
||||
}
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
require.NoError(t, err, "marking rule should be inserted")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
||||
require.True(t, foundNat, "nat rule should exist in the map")
|
||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[natRuleKey]
|
||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", testCase.InputPair.Source.String(),
|
||||
"-d", testCase.InputPair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "income nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||
require.True(t, exists, "marking rule should be created")
|
||||
foundRule, found := manager.rules[natRuleKey]
|
||||
require.True(t, found, "marking rule should exist in the map")
|
||||
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[inNatRuleKey]
|
||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||
require.False(t, exists, "marking rule should not be created")
|
||||
_, found := manager.rules[natRuleKey]
|
||||
require.False(t, found, "marking rule should not exist in the map")
|
||||
}
|
||||
|
||||
// Check inverse rule
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", inversePair.Source.String(),
|
||||
"-d", inversePair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "inverse marking rule should be created")
|
||||
foundRule, found := manager.rules[inverseRuleKey]
|
||||
require.True(t, found, "inverse marking rule should exist in the map")
|
||||
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "inverse marking rule should not be created")
|
||||
_, found := manager.rules[inverseRuleKey]
|
||||
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
@@ -137,42 +155,52 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule without error")
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", testCase.InputPair.Source.String(),
|
||||
"-d", testCase.InputPair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
require.False(t, exists, "marking rule should not exist")
|
||||
|
||||
_, found := manager.rules[natRuleKey]
|
||||
require.False(t, found, "nat rule should exist in the manager map")
|
||||
require.False(t, found, "marking rule should not exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "income nat rule should not exist")
|
||||
// Check inverse rule removal
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", inversePair.Source.String(),
|
||||
"-d", inversePair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
_, found = manager.rules[inNatRuleKey]
|
||||
require.False(t, found, "income nat rule should exist in the manager map")
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
require.False(t, exists, "inverse marking rule should not exist")
|
||||
|
||||
_, found = manager.rules[inverseRuleKey]
|
||||
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
PreroutingFormat = "netbird-prerouting-%s-%t"
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
)
|
||||
|
||||
|
||||
@@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
@@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -124,7 +125,6 @@ func (r *router) createContainers() error {
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
@@ -133,6 +133,21 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Chain is created by acl manager
|
||||
// TODO: move creation to a common place
|
||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: r.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
@@ -422,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
dir := expr.MetaKeyIIFNAME
|
||||
notDir := expr.MetaKeyOIFNAME
|
||||
op := expr.CmpOpEq
|
||||
if pair.Inverse {
|
||||
dir = expr.MetaKeyOIFNAME
|
||||
notDir = expr.MetaKeyIIFNAME
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
lo := ifname("lo")
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: dir,
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: notDir,
|
||||
Register: 1,
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: lo,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
|
||||
// interface matching
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: op,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Counter{}, &expr.Masq{},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Chain: r.chains[chainNamePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPostroutingRules adds the masquerade rules
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
})
|
||||
|
||||
// Second masquerade rule for traffic going out through WireGuard interface
|
||||
exprs2 := []expr.Any{
|
||||
// Match on the second fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
},
|
||||
|
||||
// Match WireGuard interface
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -723,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -749,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
|
||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: nat rule %s not found", ruleKey)
|
||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -32,100 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(table, ifaceMock)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
require.NoError(t, manager.init(table))
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
rtr := manager.router
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
|
||||
defer func(manager *router, pair firewall.RouterPair) {
|
||||
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
|
||||
}(manager, testCase.InputPair)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
|
||||
})
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
// Build expected expressions for connection tracking
|
||||
conntrackExprs := []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
// Build interface matching expression
|
||||
ifaceExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
testingExpression := append(conntrackExprs, ifaceExprs...)
|
||||
testingExpression = append(testingExpression, sourceExp...)
|
||||
testingExpression = append(testingExpression, destExp...)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||
found = 1
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
// Compare expressions up to the mark setting expressions
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -135,68 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(table, ifaceMock)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
require.NoError(t, manager.init(table))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natRuleKey),
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
|
||||
rtr := manager.router
|
||||
|
||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
// First add the NAT rule using the router's method
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule")
|
||||
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(inNatRuleKey),
|
||||
})
|
||||
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||
}
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, found, "NAT rule should exist before removal")
|
||||
|
||||
// Now remove the rule
|
||||
err = rtr.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error when removing rule")
|
||||
|
||||
// Verify the rule was removed
|
||||
found = false
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.False(t, found, "NAT rule should not exist after removal")
|
||||
|
||||
// Verify the static postrouting rules still exist
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
|
||||
require.NoError(t, err, "should list postrouting rules")
|
||||
foundCounter := false
|
||||
for _, rule := range rules {
|
||||
for _, e := range rule.Exprs {
|
||||
if _, ok := e.(*expr.Counter); ok {
|
||||
foundCounter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundCounter {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundCounter, "static postrouting rule should remain")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
@@ -24,8 +25,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -154,7 +155,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
@@ -166,16 +167,30 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
msgs := getMessages(msgsPool)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer putMessages(msgs, msgsPool)
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||
//nolint
|
||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
@@ -191,11 +206,12 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
continue
|
||||
}
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
@@ -273,3 +289,15 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
|
||||
return p.remoteConn.Close()
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
@@ -104,8 +108,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
buf := make([]byte, 1500)
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
||||
@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
|
||||
@@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||
defer func() {
|
||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkSignalDisconnected(err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
|
||||
@@ -309,6 +309,11 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
return
|
||||
}
|
||||
|
||||
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
||||
conn.log.Errorf("remote ICE connection is nil")
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
if conn.currentConnPriority > priority {
|
||||
@@ -437,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
@@ -460,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
wgConfigWorkaround()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
@@ -731,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = proxy
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
21
client/internal/peer/nilcheck.go
Normal file
21
client/internal/peer/nilcheck.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func remoteConnNil(log *log.Entry, conn net.Conn) bool {
|
||||
if conn == nil {
|
||||
log.Errorf("ice conn is nil")
|
||||
return true
|
||||
}
|
||||
|
||||
if conn.RemoteAddr() == nil {
|
||||
log.Errorf("ICE remote address is nil")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
|
||||
func (s *State) GetRoutes() map[string]struct{} {
|
||||
s.Mux.RLock()
|
||||
defer s.Mux.RUnlock()
|
||||
return s.routes
|
||||
return maps.Clone(s.routes)
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
if receivedState.GetRoutes() != nil {
|
||||
peerState.SetRoutes(receivedState.GetRoutes())
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
|
||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||
@@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.AddRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.DeleteRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
|
||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
ch, found := d.changeNotify[peer]
|
||||
if !found || ch == nil {
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
if found {
|
||||
return ch
|
||||
}
|
||||
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
@@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
ch, found := d.changeNotify[peerID]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(d.changeNotify, peerID)
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(d.numOfPeers())
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerState(peerState)
|
||||
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
select {
|
||||
|
||||
@@ -57,6 +57,9 @@ type WorkerICE struct {
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
||||
@@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
|
||||
return
|
||||
}
|
||||
|
||||
err := w.agent.Close()
|
||||
if err != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
|
||||
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
|
||||
w.muxAgent.Lock()
|
||||
agentCancel()
|
||||
_ = agent.Close()
|
||||
w.agent = nil
|
||||
|
||||
w.muxAgent.Unlock()
|
||||
switch state {
|
||||
case ice.ConnectionStateConnected:
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
}
|
||||
w.closeAgent(agentCancel)
|
||||
default:
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
@@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
cancel()
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
w.agent = nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
// wait local endpoint configuration
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
tempScore = float64(metricDiff) * 10
|
||||
}
|
||||
|
||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||
latency := time.Second
|
||||
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
|
||||
latency := 999 * time.Millisecond
|
||||
if peerStatus.latency != 0 {
|
||||
latency = peerStatus.latency
|
||||
} else {
|
||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
||||
}
|
||||
|
||||
// avoid negative tempScore on the higher latency calculation
|
||||
if latency > 1*time.Second {
|
||||
latency = 999 * time.Millisecond
|
||||
}
|
||||
|
||||
// higher latency is worse score
|
||||
tempScore += 1 - latency.Seconds()
|
||||
|
||||
if !peerStatus.relayed {
|
||||
@@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
||||
|
||||
switch {
|
||||
case chosen == "":
|
||||
var peers []string
|
||||
@@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
for _, r := range c.routes {
|
||||
_, found := c.routePeersNotifiers[r.Peer]
|
||||
if !found {
|
||||
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
closerChan := make(chan struct{})
|
||||
c.routePeersNotifiers[r.Peer] = closerChan
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
@@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
@@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
} else {
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
}
|
||||
@@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
c.addStateRoute()
|
||||
|
||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add peer state route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
|
||||
@@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
|
||||
7
client/server/panic_generic.go
Normal file
7
client/server/panic_generic.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
func handlePanicLog() error {
|
||||
return nil
|
||||
}
|
||||
83
client/server/panic_windows.go
Normal file
83
client/server/panic_windows.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
|
||||
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
|
||||
stdErrorHandle = ^uintptr(11)
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/console/setstdhandle
|
||||
setStdHandleFn = kernel32.NewProc("SetStdHandle")
|
||||
)
|
||||
|
||||
func handlePanicLog() error {
|
||||
logPath := os.Getenv(windowsPanicLogEnvVar)
|
||||
if logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
logDir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||
return fmt.Errorf("create panic log directory: %w", err)
|
||||
}
|
||||
if err := util.EnforcePermission(logPath); err != nil {
|
||||
return fmt.Errorf("enforce permission on panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Open log file with append mode
|
||||
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Redirect stderr to the file
|
||||
if err = redirectStderr(f); err != nil {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after redirect error: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("redirect stderr: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("successfully configured panic logging to: %s", logPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// redirectStderr redirects stderr to the provided file
|
||||
func redirectStderr(f *os.File) error {
|
||||
// Get the current process's stderr handle
|
||||
if err := setStdHandle(f); err != nil {
|
||||
return fmt.Errorf("failed to set stderr handle: %w", err)
|
||||
}
|
||||
|
||||
// Also set os.Stderr for Go's standard library
|
||||
os.Stderr = f
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setStdHandle(f *os.File) error {
|
||||
handle := f.Fd()
|
||||
r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle)
|
||||
if r0 == 0 {
|
||||
if e1 != nil {
|
||||
return e1
|
||||
}
|
||||
return syscall.EINVAL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -97,6 +97,10 @@ func (s *Server) Start() error {
|
||||
defer s.mutex.Unlock()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
|
||||
if err := handlePanicLog(); err != nil {
|
||||
log.Warnf("failed to redirect stderr: %v", err)
|
||||
}
|
||||
|
||||
if err := restoreResidualState(s.rootCtx); err != nil {
|
||||
log.Warnf(errRestoreResidualState, err)
|
||||
}
|
||||
@@ -622,6 +626,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
|
||||
if s.actCancel == nil {
|
||||
return nil, fmt.Errorf("service is not up")
|
||||
}
|
||||
|
||||
3
client/testdata/store.sql
vendored
3
client/testdata/store.sql
vendored
@@ -31,6 +31,9 @@ INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-0
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,'');
|
||||
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
|
||||
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
COMMIT;
|
||||
|
||||
126
funding.json
Normal file
126
funding.json
Normal file
@@ -0,0 +1,126 @@
|
||||
{
|
||||
"version": "v1.0.0",
|
||||
"entity": {
|
||||
"type": "organisation",
|
||||
"role": "owner",
|
||||
"name": "NetBird GmbH",
|
||||
"email": "hello@netbird.io",
|
||||
"phone": "",
|
||||
"description": "NetBird GmbH is a Berlin-based software company specializing in the development of open-source network security solutions. Network security is utterly complex and expensive, accessible only to companies with multi-million dollar IT budgets. In contrast, there are millions of companies left behind. Our mission is to create an advanced network and cybersecurity platform that is both easy-to-use and affordable for teams of all sizes and budgets. By leveraging the open-source strategy and technological advancements, NetBird aims to set the industry standard for connecting and securing IT infrastructure.",
|
||||
"webpageUrl": {
|
||||
"url": "https://github.com/netbirdio"
|
||||
}
|
||||
},
|
||||
"projects": [
|
||||
{
|
||||
"guid": "netbird",
|
||||
"name": "NetBird",
|
||||
"description": "NetBird is a configuration-free peer-to-peer private network and a centralized access control system combined in a single open-source platform. It makes it easy to create secure WireGuard-based private networks for your organization or home.",
|
||||
"webpageUrl": {
|
||||
"url": "https://github.com/netbirdio/netbird"
|
||||
},
|
||||
"repositoryUrl": {
|
||||
"url": "https://github.com/netbirdio/netbird"
|
||||
},
|
||||
"licenses": [
|
||||
"BSD-3"
|
||||
],
|
||||
"tags": [
|
||||
"network-security",
|
||||
"vpn",
|
||||
"developer-tools",
|
||||
"ztna",
|
||||
"zero-trust",
|
||||
"remote-access",
|
||||
"wireguard",
|
||||
"peer-to-peer",
|
||||
"private-networking",
|
||||
"software-defined-networking"
|
||||
]
|
||||
}
|
||||
],
|
||||
"funding": {
|
||||
"channels": [
|
||||
{
|
||||
"guid": "github-sponsors",
|
||||
"type": "payment-provider",
|
||||
"address": "https://github.com/sponsors/netbirdio",
|
||||
"description": ""
|
||||
},
|
||||
{
|
||||
"guid": "bank-transfer",
|
||||
"type": "bank",
|
||||
"address": "",
|
||||
"description": "Contact us at hello@netbird.io for bank transfer details."
|
||||
}
|
||||
],
|
||||
"plans": [
|
||||
{
|
||||
"guid": "support-yearly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - Yearly",
|
||||
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 100000,
|
||||
"currency": "USD",
|
||||
"frequency": "yearly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-one-time-year",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - One Year",
|
||||
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 100000,
|
||||
"currency": "USD",
|
||||
"frequency": "one-time",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-one-time-monthly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - Monthly",
|
||||
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 10000,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-monthly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - One Month",
|
||||
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 10000,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "goodwill",
|
||||
"status": "active",
|
||||
"name": "Goodwill Plan",
|
||||
"description": "Pay anything you wish to show your goodwill for the project.",
|
||||
"amount": 0,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
}
|
||||
],
|
||||
"history": null
|
||||
}
|
||||
}
|
||||
11
go.mod
11
go.mod
@@ -60,7 +60,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
@@ -71,7 +71,6 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/r3labs/diff/v3 v3.0.1
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
@@ -156,7 +155,7 @@ require (
|
||||
github.com/go-text/typesetting v0.1.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
@@ -211,8 +210,6 @@ require (
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
@@ -231,7 +228,7 @@ require (
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
|
||||
k8s.io/apimachinery v0.26.2 // indirect
|
||||
)
|
||||
|
||||
@@ -239,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
22
go.sum
22
go.sum
@@ -297,8 +297,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
@@ -521,14 +521,14 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
@@ -605,8 +605,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
||||
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@@ -699,10 +697,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -1238,8 +1232,8 @@ gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,14 +29,18 @@ import (
|
||||
)
|
||||
|
||||
type MocIntegratedValidator struct {
|
||||
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
return update, nil
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
if a.ValidatePeerFunc != nil {
|
||||
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
||||
}
|
||||
return update, false, nil
|
||||
}
|
||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
@@ -397,7 +401,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
store := newStore(t)
|
||||
|
||||
err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "account-1")
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -415,6 +426,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
|
||||
store.Close(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -422,27 +435,35 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
|
||||
err := newAccountWithId(context.Background(), store, accountID, userId, domain)
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
func TestAccountManager_GetOrCreateAccountIDByUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userID)
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
account, err := manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
|
||||
return
|
||||
@@ -665,15 +686,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountIDByUserID where the id is getting generated
|
||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
@@ -689,44 +707,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
require.Len(t, accountGroups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
|
||||
require.NoError(t, err, "failed to update account settings")
|
||||
|
||||
totalAccounts, err := manager.Store.GetTotalAccounts(context.Background())
|
||||
require.NoError(t, err, "failed to get total accounts")
|
||||
require.Equal(t, int64(1), totalAccounts, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
|
||||
require.NoError(t, err, "failed to update account settings")
|
||||
|
||||
totalAccounts, err := manager.Store.GetTotalAccounts(context.Background())
|
||||
require.NoError(t, err, "failed to get total accounts")
|
||||
require.Equal(t, int64(1), totalAccounts, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to check account existence")
|
||||
require.True(t, exists, "account should exist")
|
||||
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
require.Len(t, accountGroups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*group.Group{}
|
||||
for _, g := range account.Groups {
|
||||
for _, g := range accountGroups {
|
||||
groupsByNames[g.Name] = g
|
||||
}
|
||||
|
||||
@@ -742,62 +769,55 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
func TestAccountManager_GetAccountInfoFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
|
||||
userPAT := &PersonalAccessToken{
|
||||
ID: "tokenId",
|
||||
UserID: "testuser",
|
||||
HashedToken: encodedHashedToken,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
|
||||
require.NoError(t, err, "failed to save PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
}
|
||||
|
||||
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
|
||||
user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", account.Id)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
||||
assert.Equal(t, "account_id", user.AccountID)
|
||||
assert.Equal(t, "testuser", user.Id)
|
||||
assert.Equal(t, userPAT, pat)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
LastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
|
||||
userPAT := &PersonalAccessToken{
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
LastUsed: time.Time{},
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
|
||||
require.NoError(t, err, "failed to save PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -808,11 +828,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), "account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
|
||||
userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID)
|
||||
require.NoError(t, err, "failed to get PAT")
|
||||
|
||||
assert.True(t, !userPAT.LastUsed.IsZero())
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
@@ -823,15 +842,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
account, err := manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
|
||||
}
|
||||
@@ -850,32 +869,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
domain := "hotmail.com"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId)
|
||||
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
|
||||
}
|
||||
accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account domain and category")
|
||||
require.Equal(t, domain, accDomain, "expected account domain to match")
|
||||
|
||||
domain = "gmail.com"
|
||||
|
||||
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||
}
|
||||
accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
|
||||
if account == nil {
|
||||
t.Fatalf("expected to get an account for a user %s", userId)
|
||||
}
|
||||
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
|
||||
}
|
||||
accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account domain and category")
|
||||
require.Equal(t, domain, accDomain, "expected account domain to match")
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
@@ -907,12 +916,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
return am.Store.GetAccount(context.Background(), accountID)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccount(t *testing.T) {
|
||||
@@ -978,6 +986,110 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "example.com",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
}
|
||||
|
||||
publicClaims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "test.com",
|
||||
UserId: "public-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
}
|
||||
|
||||
am, err := createManager(b)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
return
|
||||
}
|
||||
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
users := genUsers("priv", 100)
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), id)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
acc.Users = users
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), acc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
userP := genUsers("pub", 100)
|
||||
|
||||
pacc, err := am.Store.GetAccount(context.Background(), pid)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pacc.Users = userP
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), pacc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.Run("public without account ID", func(b *testing.B) {
|
||||
//b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("private without account ID", func(b *testing.B) {
|
||||
//b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("private with account ID", func(b *testing.B) {
|
||||
claims.AccountId = id
|
||||
//b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func genUsers(p string, n int) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
now := time.Now()
|
||||
for i := 0; i < n; i++ {
|
||||
users[fmt.Sprintf("%s-%d", p, i)] = &User{
|
||||
Id: fmt.Sprintf("%s-%d", p, i),
|
||||
Role: UserRoleAdmin,
|
||||
LastLogin: now,
|
||||
CreatedAt: now,
|
||||
Issued: "api",
|
||||
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@@ -1055,23 +1167,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud")
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account network")
|
||||
|
||||
if account.Network.Serial != 0 {
|
||||
t.Errorf("expecting account network to have an initial Serial=0")
|
||||
return
|
||||
}
|
||||
serial := network.CurrentSerial() // should be 0
|
||||
require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0")
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "failed to generate private key")
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
@@ -1079,16 +1186,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "failed to add peer")
|
||||
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
if peer.Key != expectedPeerKey {
|
||||
t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key)
|
||||
@@ -1130,8 +1231,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1142,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1212,19 +1311,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1237,7 +1323,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1258,7 +1356,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1269,9 +1367,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1305,13 +1402,20 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1322,14 +1426,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1352,7 +1450,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1367,7 +1465,6 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
userID := "account_creator"
|
||||
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1396,7 +1493,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID)
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer.ID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -1418,7 +1515,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
assert.Equal(t, peer.Name, ev.Meta["name"])
|
||||
assert.Equal(t, peer.FQDN(account.Domain), ev.Meta["fqdn"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
assert.Equal(t, peer.IP.String(), ev.TargetID)
|
||||
assert.Equal(t, peer.ID, ev.TargetID)
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
@@ -1748,16 +1845,15 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
settings.PeerLoginExpirationEnabled = true
|
||||
settings.PeerLoginExpiration = time.Hour
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -1774,11 +1870,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
// disable expiration first
|
||||
update := peer.Copy()
|
||||
update.LoginExpirationEnabled = false
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
// enabling expiration should trigger the routine
|
||||
update.LoginExpirationEnabled = true
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1802,10 +1898,14 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
|
||||
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
settings.PeerLoginExpirationEnabled = true
|
||||
settings.PeerLoginExpiration = time.Hour
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -1822,11 +1922,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1857,7 +1954,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -1870,11 +1967,15 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
settings.PeerLoginExpirationEnabled = true
|
||||
settings.PeerLoginExpiration = time.Hour
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1884,10 +1985,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
wg.Add(1)
|
||||
|
||||
// disabling PeerLoginExpirationEnabled should trigger cancel
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
settings.PeerLoginExpirationEnabled = false
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
failed = waitTimeout(wg, time.Second)
|
||||
if failed {
|
||||
@@ -1902,30 +2001,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
settings.PeerLoginExpirationEnabled = false
|
||||
settings.PeerLoginExpiration = time.Hour
|
||||
|
||||
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
settings.PeerLoginExpiration = time.Second
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
settings.PeerLoginExpiration = time.Hour * 24 * 181
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||
}
|
||||
|
||||
@@ -2606,7 +2704,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2626,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2665,7 +2763,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
package differs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
|
||||
"github.com/r3labs/diff/v3"
|
||||
)
|
||||
|
||||
// NetIPAddr is a custom differ for netip.Addr
|
||||
type NetIPAddr struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromAddr, ok1 := a.Interface().(netip.Addr)
|
||||
toAddr, ok2 := b.Interface().(netip.Addr)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromAddr.String() != toAddr.String() {
|
||||
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
|
||||
// NetIPPrefix is a custom differ for netip.Prefix
|
||||
type NetIPPrefix struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
||||
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromPrefix.String() != toPrefix.String() {
|
||||
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
|
||||
}
|
||||
|
||||
if dnsSettingsToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
|
||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
oldSettings := account.DNSSettings.Copy()
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
if !user.HasAdminPower() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range addedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
for _, groupID := range removedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
|
||||
@@ -8,9 +8,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -38,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
accountID, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Fatal("failed to init testing account")
|
||||
}
|
||||
|
||||
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@@ -52,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Fatal("DNS settings for new accounts shouldn't return nil")
|
||||
}
|
||||
|
||||
account.DNSSettings = DNSSettings{
|
||||
err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{group1ID},
|
||||
}
|
||||
})
|
||||
require.NoError(t, err, "failed to update DNS settings")
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save testing account with new DNS settings")
|
||||
}
|
||||
|
||||
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@@ -70,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
|
||||
}
|
||||
|
||||
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
|
||||
_, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID)
|
||||
if err == nil {
|
||||
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
|
||||
}
|
||||
@@ -125,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
accountID, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
|
||||
err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
@@ -138,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
updatedAccount, err := am.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Errorf("should be able to retrieve updated account, got err: %s", err)
|
||||
}
|
||||
@@ -157,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
accountID, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
peer1, err := account.FindPeerByPubKey(dnsPeer1Key)
|
||||
peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
peer2, err := account.FindPeerByPubKey(dnsPeer2Key)
|
||||
peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
@@ -178,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||
|
||||
dnsSettings := account.DNSSettings.Copy()
|
||||
accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account DNS settings")
|
||||
|
||||
dnsSettings := accountDNSSettings.Copy()
|
||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||
account.DNSSettings = dnsSettings
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings)
|
||||
require.NoError(t, err, "failed to update DNS settings")
|
||||
|
||||
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
|
||||
require.NoError(t, err)
|
||||
@@ -221,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
|
||||
t.Helper()
|
||||
peer1 := &nbpeer.Peer{
|
||||
Key: dnsPeer1Key,
|
||||
@@ -256,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
|
||||
account.Users[dnsRegularUserID] = &User{
|
||||
Id: dnsRegularUserID,
|
||||
Role: UserRoleUser,
|
||||
err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: dnsRegularUserID,
|
||||
AccountID: dnsAccountID,
|
||||
Role: UserRoleUser,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer1, err = account.FindPeerByPubKey(peer1.Key)
|
||||
_, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = account.FindPeerByPubKey(peer2.Key)
|
||||
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{
|
||||
{
|
||||
ID: dnsGroup1ID,
|
||||
AccountID: dnsAccountID,
|
||||
Peers: []string{peer1.ID},
|
||||
Name: dnsGroup1ID,
|
||||
},
|
||||
{
|
||||
ID: dnsGroup2ID,
|
||||
AccountID: dnsAccountID,
|
||||
Name: dnsGroup2ID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
newGroup1 := &group.Group{
|
||||
ID: dnsGroup1ID,
|
||||
Peers: []string{peer1.ID},
|
||||
Name: dnsGroup1ID,
|
||||
}
|
||||
|
||||
newGroup2 := &group.Group{
|
||||
ID: dnsGroup2ID,
|
||||
Name: dnsGroup2ID,
|
||||
}
|
||||
|
||||
account.Groups[newGroup1.ID] = newGroup1
|
||||
account.Groups[newGroup2.ID] = newGroup2
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
|
||||
ID: dnsNSGroup1,
|
||||
Name: "ns-group-1",
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{
|
||||
ID: dnsNSGroup1,
|
||||
AccountID: dnsAccountID,
|
||||
Name: "ns-group-1",
|
||||
NameServers: []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(savedPeer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
@@ -322,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
Primary: true,
|
||||
Enabled: true,
|
||||
Groups: []string{allGroup.ID},
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
return dnsAccountID, nil
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
@@ -521,23 +519,64 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Creating DNS settings with groups that have no peers should not update account peers or send peer update
|
||||
t.Run("creating dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupB"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Creating DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("creating dns setting with used groups", func(t *testing.T) {
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
@@ -559,27 +598,6 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -20,10 +20,10 @@ var (
|
||||
)
|
||||
|
||||
type ephemeralPeer struct {
|
||||
id string
|
||||
account *Account
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
id string
|
||||
accountID string
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
}
|
||||
|
||||
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
||||
@@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
|
||||
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
|
||||
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
@@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
return
|
||||
}
|
||||
|
||||
e.addPeer(peer.ID, a, newDeadLine())
|
||||
e.addPeer(peer.AccountID, peer.ID, newDeadLine())
|
||||
if e.timer == nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
e.cleanup(ctx)
|
||||
@@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
accounts := e.store.GetAllAccounts(context.Background())
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
t := newDeadLine()
|
||||
count := 0
|
||||
for _, a := range accounts {
|
||||
for id, p := range a.Peers {
|
||||
if p.Ephemeral {
|
||||
count++
|
||||
e.addPeer(id, a, t)
|
||||
}
|
||||
for _, p := range peers {
|
||||
if p.Ephemeral {
|
||||
count++
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
|
||||
}
|
||||
|
||||
@@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
for id, p := range deletePeers {
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
|
||||
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
|
||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||
ep := &ephemeralPeer{
|
||||
id: id,
|
||||
account: account,
|
||||
deadline: deadline,
|
||||
id: peerID,
|
||||
accountID: accountID,
|
||||
deadline: deadline,
|
||||
}
|
||||
|
||||
if e.headPeer == nil {
|
||||
|
||||
@@ -7,25 +7,12 @@ import (
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type MockStore struct {
|
||||
Store
|
||||
account *Account
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
|
||||
return []*Account{s.account}
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
|
||||
_, ok := s.account.Peers[peerId]
|
||||
if ok {
|
||||
return s.account, nil
|
||||
}
|
||||
|
||||
return nil, status.NewPeerNotFoundError(peerId)
|
||||
accountID string
|
||||
}
|
||||
|
||||
type MocAccountManager struct {
|
||||
@@ -33,9 +20,8 @@ type MocAccountManager struct {
|
||||
store *MockStore
|
||||
}
|
||||
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
delete(a.store.account.Peers, peerID)
|
||||
return nil //nolint:nil
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error {
|
||||
return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
@@ -44,23 +30,26 @@ func TestNewManager(t *testing.T) {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
store := &MockStore{
|
||||
Store: newStore(t),
|
||||
}
|
||||
am := MocAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
require.NoError(t, err, "failed to seed peers")
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
if len(store.account.Peers) != numberOfPeers {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
|
||||
}
|
||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers")
|
||||
}
|
||||
|
||||
func TestNewManagerPeerConnected(t *testing.T) {
|
||||
@@ -69,26 +58,32 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
store := &MockStore{
|
||||
Store: newStore(t),
|
||||
}
|
||||
am := MocAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
require.NoError(t, err, "failed to seed peers")
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
|
||||
require.NoError(t, err, "failed to get peer")
|
||||
|
||||
mgr.OnPeerConnected(context.Background(), peer)
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers")
|
||||
}
|
||||
|
||||
func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
@@ -97,50 +92,73 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
store := &MockStore{
|
||||
Store: newStore(t),
|
||||
}
|
||||
am := MocAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
require.NoError(t, err, "failed to seed peers")
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
|
||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
for _, v := range peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
}
|
||||
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
|
||||
require.NoError(t, err, "failed to get peer")
|
||||
mgr.OnPeerDisconnected(context.Background(), peer)
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers")
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error {
|
||||
accountID := "my account"
|
||||
err := newAccountWithId(context.Background(), store, accountID, "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store.accountID = accountID
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
AccountID: accountID,
|
||||
Ephemeral: false,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
err = store.AddPeerToAccount(context.Background(), p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < numberOfEphemeralPeers; i++ {
|
||||
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
AccountID: accountID,
|
||||
Ephemeral: true,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
err = store.AddPeerToAccount(context.Background(), p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -49,8 +53,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -58,13 +61,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetAccountGroups(ctx, accountID)
|
||||
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@@ -77,79 +79,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
||||
// SaveGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
for _, newGroup := range newGroups {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
return err
|
||||
}
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Avoid duplicate groups only for the API issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
newGroup.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if account.Peers[peerID] == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldGroup := account.Groups[newGroup.ID]
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
newGroupIDs := make([]string, 0, len(newGroups))
|
||||
for _, newGroup := range newGroups {
|
||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
if oldGroup != nil {
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
@@ -159,35 +156,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range addedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, peerID := range addedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
|
||||
continue
|
||||
}
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range removedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
for _, peerID := range removedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
|
||||
continue
|
||||
}
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -210,40 +214,47 @@ func difference(a, b []string) []string {
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var group *nbgroup.Group
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -254,93 +265,90 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
//
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
|
||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*nbgroup.Group
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range deletedGroups {
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
for _, group := range deletedGroups {
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
for _, item := range account.Groups {
|
||||
groups = append(groups, item)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
add := true
|
||||
for _, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
add = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if add {
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -348,41 +356,80 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
||||
// validateNewGroup validates the new group for existence and required fields.
|
||||
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Prevent duplicate groups for API-issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userID]
|
||||
if executingUser == nil {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
@@ -390,51 +437,77 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
|
||||
}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
||||
return checkGroupLinkedToSettings(ctx, transaction, group)
|
||||
}
|
||||
|
||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||
@@ -446,7 +519,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, dns := range nameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
@@ -454,11 +533,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, setupKey := range setupKeys {
|
||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||
return true, setupKey
|
||||
@@ -468,7 +554,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if slices.Contains(user.AutoGroups, groupID) {
|
||||
return true, user
|
||||
@@ -477,31 +569,47 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
}
|
||||
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||
return true
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -49,3 +49,35 @@ func (g *Group) Copy() *Group {
|
||||
func (g *Group) HasPeers() bool {
|
||||
return len(g.Peers) > 0
|
||||
}
|
||||
|
||||
// IsGroupAll checks if the group is a default "All" group.
|
||||
func (g *Group) IsGroupAll() bool {
|
||||
return g.Name == "All"
|
||||
}
|
||||
|
||||
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||
func (g *Group) AddPeer(peerID string) bool {
|
||||
if peerID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
g.Peers = append(g.Peers, peerID)
|
||||
return true
|
||||
}
|
||||
|
||||
// RemovePeer removes peerID from Peers if present, returning true if removed.
|
||||
func (g *Group) RemovePeer(peerID string) bool {
|
||||
for i, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
90
management/server/group/group_test.go
Normal file
90
management/server/group/group_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddPeer(t *testing.T) {
|
||||
t.Run("add new peer to empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to non-empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add duplicate peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("add empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemovePeer(t *testing.T) {
|
||||
t.Run("remove existing peer from slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
|
||||
peerID := "peer2"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Nil(t, group.Peers)
|
||||
})
|
||||
|
||||
t.Run("remove non-existent peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from single-item slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1"}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("remove empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -327,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
}
|
||||
|
||||
routeResource := &route.Route{
|
||||
ID: "example route",
|
||||
Groups: []string{groupForRoute.ID},
|
||||
ID: "example route",
|
||||
AccountID: accountID,
|
||||
Groups: []string{groupForRoute.ID},
|
||||
}
|
||||
|
||||
routePeerGroupResource := &route.Route{
|
||||
ID: "example route with peer groups",
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupForRoute2.ID},
|
||||
}
|
||||
|
||||
nameServerGroup := &nbdns.NameServerGroup{
|
||||
ID: "example name server group",
|
||||
Groups: []string{groupForNameServerGroups.ID},
|
||||
ID: "example name server group",
|
||||
AccountID: accountID,
|
||||
Groups: []string{groupForNameServerGroups.ID},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
ID: "example policy",
|
||||
ID: "example policy",
|
||||
AccountID: accountID,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "example policy rule",
|
||||
PolicyID: "example policy",
|
||||
Destinations: []string{groupForPolicies.ID},
|
||||
},
|
||||
},
|
||||
@@ -353,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
|
||||
setupKey := &SetupKey{
|
||||
Id: "example setup key",
|
||||
AccountID: accountID,
|
||||
AutoGroups: []string{groupForSetupKeys.ID},
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Id: "example user",
|
||||
AccountID: accountID,
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
account.Policies = append(account.Policies, policy)
|
||||
account.SetupKeys[setupKey.Id] = setupKey
|
||||
account.Users[user.Id] = user
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{
|
||||
groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies,
|
||||
groupForSetupKeys, groupForUsers, groupForIntegration,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -499,8 +530,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -511,7 +541,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving a group linked to policy should update account peers and send peer update
|
||||
@@ -536,29 +566,6 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving an unchanged group should trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding peer to a used group should update account peers and send peer update
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -38,6 +39,7 @@ type GRPCServer struct {
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
@@ -171,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -190,11 +200,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@@ -245,10 +259,18 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||
}
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
@@ -274,6 +296,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// maps internal internalStatus.Error to gRPC status.Error
|
||||
func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
|
||||
@@ -100,13 +100,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
|
||||
resp := toAccountResponse(accountID, updatedSettings)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -39,9 +39,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
accCopy := account.Copy()
|
||||
accCopy.UpdateSettings(newSettings)
|
||||
return accCopy, nil
|
||||
return newSettings.Copy(), nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
||||
@@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
accountManager.GetAccountInfoFromPAT,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
@@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
@@ -47,7 +47,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
@@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
@@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
|
||||
@@ -33,7 +33,8 @@ var testAccount = &server.Account{
|
||||
Domain: domain,
|
||||
Users: map[string]*server.User{
|
||||
userID: {
|
||||
Id: userID,
|
||||
Id: userID,
|
||||
AccountID: accountID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
tokenID: {
|
||||
ID: tokenID,
|
||||
@@ -49,11 +50,11 @@ var testAccount = &server.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
@@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
)
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountFromPAT,
|
||||
mockGetAccountInfoFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
|
||||
@@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
return peerToReturn, nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
@@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupsInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
groupsInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
@@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
@@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||
return
|
||||
case http.MethodGet, http.MethodPut:
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
h.getPeer(r.Context(), account, peerID, userID, w)
|
||||
} else {
|
||||
h.updatePeer(r.Context(), account, userID, peerID, w, r)
|
||||
}
|
||||
case http.MethodGet:
|
||||
h.getPeer(r.Context(), accountID, peerID, userID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
|
||||
return
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
@@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
@@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsChecked := make(map[string]struct{})
|
||||
func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
|
||||
groupsInfo := make([]api.GroupMinimum, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
_, ok := groupsChecked[group.ID]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
groupsChecked[group.ID] = struct{}{}
|
||||
for _, pk := range group.Peers {
|
||||
if pk == peerID {
|
||||
info := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
}
|
||||
groupsInfo = append(groupsInfo, info)
|
||||
break
|
||||
}
|
||||
}
|
||||
groupsInfo = append(groupsInfo, api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
})
|
||||
}
|
||||
return groupsInfo
|
||||
}
|
||||
|
||||
@@ -39,6 +39,68 @@ const (
|
||||
)
|
||||
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: "test_id",
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
Name: "rule",
|
||||
Enabled: true,
|
||||
Action: "accept",
|
||||
Destinations: []string{"group1"},
|
||||
Sources: []string{"group1"},
|
||||
Bidirectional: true,
|
||||
Protocol: "all",
|
||||
Ports: []string{"80"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := server.NewRegularUser(serviceUser)
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: "test_id",
|
||||
Domain: "hotmail.com",
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
adminUser: server.NewAdminUser(adminUser),
|
||||
regularUser: server.NewRegularUser(regularUser),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: "test_id",
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
},
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
Policies: []*server.Policy{policy},
|
||||
Network: &server.Network{
|
||||
Identifier: "ciclqisab2ss43jdn8q0",
|
||||
Net: net.IPNet{
|
||||
IP: net.ParseIP("100.67.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||
},
|
||||
Serial: 51,
|
||||
},
|
||||
}
|
||||
|
||||
return &PeersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
@@ -67,74 +129,31 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||
peersID := make([]string, len(peers))
|
||||
for _, peer := range peers {
|
||||
peersID = append(peersID, peer.ID)
|
||||
}
|
||||
return []*nbgroup.Group{
|
||||
{
|
||||
ID: "group1",
|
||||
AccountID: accountID,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: peersID,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
GetDNSDomainFunc: func() string {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) {
|
||||
return account, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: accountID,
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
Name: "rule",
|
||||
Enabled: true,
|
||||
Action: "accept",
|
||||
Destinations: []string{"group1"},
|
||||
Sources: []string{"group1"},
|
||||
Bidirectional: true,
|
||||
Protocol: "all",
|
||||
Ports: []string{"80"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := server.NewRegularUser(serviceUser)
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
adminUser: server.NewAdminUser(adminUser),
|
||||
regularUser: server.NewRegularUser(regularUser),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: accountID,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
},
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
Policies: []*server.Policy{policy},
|
||||
Network: &server.Network{
|
||||
Identifier: "ciclqisab2ss43jdn8q0",
|
||||
Net: net.IPNet{
|
||||
IP: net.ParseIP("100.67.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||
},
|
||||
Serial: 51,
|
||||
},
|
||||
}
|
||||
|
||||
return account, nil
|
||||
},
|
||||
HasConnectedChannelFunc: func(peerID string) bool {
|
||||
|
||||
@@ -6,10 +6,8 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -122,14 +120,9 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
return
|
||||
}
|
||||
|
||||
isUpdate := policyID != ""
|
||||
|
||||
if policyID == "" {
|
||||
policyID = xid.New().String()
|
||||
}
|
||||
|
||||
policy := server.Policy{
|
||||
policy := &server.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Description: req.Description,
|
||||
@@ -137,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
for _, rule := range req.Rules {
|
||||
pr := server.PolicyRule{
|
||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||
PolicyID: policyID,
|
||||
Name: rule.Name,
|
||||
Destinations: rule.Destinations,
|
||||
Sources: rule.Sources,
|
||||
@@ -225,7 +219,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||
}
|
||||
|
||||
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
||||
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -236,7 +231,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(allGroups, &policy)
|
||||
resp := toPolicyResponse(allGroups, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
|
||||
@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
return policy, nil
|
||||
},
|
||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||
|
||||
@@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
postureChecks.ID = "postureCheck"
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
return nil
|
||||
return postureChecks, nil
|
||||
},
|
||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||
_, ok := testPostureChecks[postureChecksID]
|
||||
|
||||
@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
||||
}
|
||||
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -52,30 +54,60 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
||||
if len(groups) == 0 {
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, group := range groups {
|
||||
var found bool
|
||||
for _, accountGroup := range accountsGroups {
|
||||
if accountGroup.ID == group {
|
||||
found = true
|
||||
break
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false, nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
|
||||
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
var err error
|
||||
var groups []*nbgroup.Group
|
||||
var peers []*nbpeer.Peer
|
||||
var settings *Settings
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err = transaction.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||
type IntegratedValidator interface {
|
||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
|
||||
|
||||
@@ -453,8 +453,8 @@ func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtr
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
return update, nil
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||
|
||||
@@ -22,9 +22,9 @@ import (
|
||||
)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
@@ -45,16 +45,16 @@ type MockAccountManager struct {
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -89,15 +89,15 @@ type MockAccountManager struct {
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
@@ -131,7 +131,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
account, err := am.GetAccountFunc(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
approvedPeers[id] = struct{}{}
|
||||
@@ -171,16 +176,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
// GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetOrCreateAccountIDByUser(
|
||||
ctx context.Context, userId, domain string,
|
||||
) (*server.Account, error) {
|
||||
if am.GetOrCreateAccountByUserFunc != nil {
|
||||
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
|
||||
) (string, error) {
|
||||
if am.GetOrCreateAccountIDByUserFunc != nil {
|
||||
return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain)
|
||||
}
|
||||
return nil, status.Errorf(
|
||||
return "", status.Errorf(
|
||||
codes.Unimplemented,
|
||||
"method GetOrCreateAccountByUser is not implemented",
|
||||
"method GetOrCreateAccountIDByUser is not implemented",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -222,19 +227,19 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(ctx, pat)
|
||||
// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, pat string) (*server.User, *server.PersonalAccessToken, string, string, error) {
|
||||
if am.GetAccountInfoFromPATFunc != nil {
|
||||
return am.GetAccountInfoFromPATFunc(ctx, pat)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
||||
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
@@ -354,14 +359,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
||||
if am.ListGroupsFunc != nil {
|
||||
return am.ListGroupsFunc(ctx, accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
|
||||
}
|
||||
|
||||
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
if am.GroupAddPeerFunc != nil {
|
||||
@@ -395,11 +392,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
||||
}
|
||||
|
||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
|
||||
if am.SavePolicyFunc != nil {
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||
}
|
||||
|
||||
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
|
||||
@@ -675,7 +672,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
|
||||
if am.UpdateAccountSettingsFunc != nil {
|
||||
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
|
||||
}
|
||||
@@ -691,9 +688,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(ctx, sync, account)
|
||||
return am.SyncPeerFunc(ctx, sync, accountID)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||
}
|
||||
@@ -739,11 +736,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
|
||||
}
|
||||
|
||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
if am.SavePostureChecksFunc != nil {
|
||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||
}
|
||||
|
||||
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
||||
@@ -840,3 +837,11 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
|
||||
}
|
||||
|
||||
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
|
||||
if am.GetPeerGroupsFunc != nil {
|
||||
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
|
||||
}
|
||||
|
||||
@@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
newNSGroup := &nbdns.NameServerGroup{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
NameServers: nameServerList,
|
||||
@@ -54,27 +62,34 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
SearchDomainsEnabled: searchDomainEnabled,
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(false, newNSGroup, account)
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.NameServerGroups == nil {
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
|
||||
}
|
||||
|
||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
}
|
||||
|
||||
@@ -87,59 +102,96 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(true, nsGroupToSave, account)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nsGroupToSave.AccountID = accountID
|
||||
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
|
||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroup := account.NameServerGroups[nsGroupID]
|
||||
if nsGroup == nil {
|
||||
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
delete(account.NameServerGroups, nsGroupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||
nsGroupID := ""
|
||||
if existingGroup {
|
||||
nsGroupID = nameserverGroup.ID
|
||||
_, found := account.NameServerGroups[nsGroupID]
|
||||
if !found {
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSList(nameserverGroup.NameServers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateGroups(nameserverGroup.Groups, account.Groups)
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validateGroups(nameserverGroup.Groups, groups)
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||
@@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
||||
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
|
||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||
}
|
||||
|
||||
for _, nsGroup := range nsGroupMap {
|
||||
for _, nsGroup := range groups {
|
||||
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
|
||||
}
|
||||
|
||||
func validateNSList(list []nbdns.NameServer) error {
|
||||
nsListLenght := len(list)
|
||||
if nsListLenght == 0 || nsListLenght > 3 {
|
||||
nsListLength := len(list)
|
||||
if nsListLength == 0 || nsListLength > 3 {
|
||||
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
|
||||
}
|
||||
return nil
|
||||
@@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
|
||||
if id == "" {
|
||||
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||
}
|
||||
found := false
|
||||
for groupID := range groups {
|
||||
if id == groupID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if _, found := groups[id]; !found {
|
||||
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||
}
|
||||
}
|
||||
@@ -277,11 +340,3 @@ func validateDomain(domain string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false
|
||||
}
|
||||
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
outNSGroup, err := am.CreateNameServerGroup(
|
||||
context.Background(),
|
||||
account.Id,
|
||||
accountID,
|
||||
testCase.inputArgs.name,
|
||||
testCase.inputArgs.description,
|
||||
testCase.inputArgs.nameServers,
|
||||
@@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("account should be saved")
|
||||
}
|
||||
testCase.existingNSGroup.AccountID = accountID
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup)
|
||||
require.NoError(t, err, "failed to save existing nameserver group")
|
||||
|
||||
var nsGroupToSave *nbdns.NameServerGroup
|
||||
|
||||
if !testCase.skipCopying {
|
||||
nsGroupToSave = testCase.existingNSGroup.Copy()
|
||||
|
||||
@@ -651,22 +648,17 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave)
|
||||
|
||||
err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave)
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
if !testCase.shouldCreate {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
|
||||
require.True(t, saved)
|
||||
savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID)
|
||||
require.NoError(t, err, "failed to get saved nameserver group")
|
||||
|
||||
testCase.expectedNSGroup.AccountID = accountID
|
||||
if !testCase.expectedNSGroup.IsEqual(savedNSGroup) {
|
||||
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup)
|
||||
}
|
||||
@@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
|
||||
testingNSGroup.AccountID = accountID
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup)
|
||||
require.NoError(t, err, "failed to save nameserver group")
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save account")
|
||||
}
|
||||
|
||||
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
|
||||
err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID)
|
||||
if err != nil {
|
||||
t.Error("deleting nameserver group failed with error: ", err)
|
||||
}
|
||||
|
||||
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Error("failed to retrieve saved account with error: ", err)
|
||||
}
|
||||
|
||||
_, found := savedAccount.NameServerGroups[testingNSGroup.ID]
|
||||
if found {
|
||||
t.Error("nameserver group shouldn't be found after delete")
|
||||
}
|
||||
_, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID)
|
||||
require.NotNil(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok, "error should be a status error")
|
||||
assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete")
|
||||
}
|
||||
|
||||
func TestGetNameServerGroup(t *testing.T) {
|
||||
@@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID)
|
||||
foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID)
|
||||
if err != nil {
|
||||
t.Error("getting existing nameserver group failed with error: ", err)
|
||||
}
|
||||
@@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
t.Error("got a nil group while getting nameserver group with ID")
|
||||
}
|
||||
|
||||
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing")
|
||||
_, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing")
|
||||
if err == nil {
|
||||
t.Error("getting not existing nameserver group should return error, got nil")
|
||||
}
|
||||
@@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
|
||||
t.Helper()
|
||||
accountID := "testingAcc"
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
Key: nsGroupPeer1Key,
|
||||
Name: "test-host1@netbird.io",
|
||||
@@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
|
||||
}
|
||||
existingNSGroup := nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
AccountID: accountID,
|
||||
Name: existingNSGroupName,
|
||||
Description: "",
|
||||
NameServers: []nbdns.NameServer{
|
||||
@@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
accountID := "testingAcc"
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
newGroup1 := &nbgroup.Group{
|
||||
ID: group1ID,
|
||||
Name: group1ID,
|
||||
}
|
||||
|
||||
newGroup2 := &nbgroup.Group{
|
||||
ID: group2ID,
|
||||
Name: group2ID,
|
||||
}
|
||||
|
||||
account.Groups[newGroup1.ID] = newGroup1
|
||||
account.Groups[newGroup2.ID] = newGroup2
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{
|
||||
{
|
||||
ID: group1ID,
|
||||
AccountID: accountID,
|
||||
Name: group1ID,
|
||||
},
|
||||
{
|
||||
ID: group2ID,
|
||||
AccountID: accountID,
|
||||
Name: group2ID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
@@ -1065,36 +1055,6 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// saving unchanged nameserver group should update account peers and not send peer update
|
||||
t.Run("saving unchanged nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
}
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting a nameserver group should update account peers and send peer update
|
||||
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -41,9 +41,9 @@ type Network struct {
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64 `diff:"-"`
|
||||
Serial uint64
|
||||
|
||||
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
|
||||
mu sync.Mutex `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -19,33 +20,33 @@ type Peer struct {
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
// The user ID that registered the peer
|
||||
UserID string `diff:"-"`
|
||||
UserID string
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
SSHEnabled bool
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool `diff:"-"`
|
||||
LoginExpirationEnabled bool
|
||||
|
||||
InactivityExpirationEnabled bool `diff:"-"`
|
||||
InactivityExpirationEnabled bool
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time `diff:"-"`
|
||||
LastLogin time.Time
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time `diff:"-"`
|
||||
CreatedAt time.Time
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool `diff:"-"`
|
||||
Ephemeral bool `gorm:"index"`
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
@@ -107,6 +108,12 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
|
||||
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
|
||||
})
|
||||
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
|
||||
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
|
||||
})
|
||||
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
|
||||
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
|
||||
})
|
||||
@@ -114,6 +121,12 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(p.Files, func(i, j int) bool {
|
||||
return p.Files[i].Path < p.Files[j].Path
|
||||
})
|
||||
sort.Slice(other.Files, func(i, j int) bool {
|
||||
return other.Files[i].Path < other.Files[j].Path
|
||||
})
|
||||
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
|
||||
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -29,3 +30,56 @@ func BenchmarkFQDN(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsEqual(t *testing.T) {
|
||||
meta1 := PeerSystemMeta{
|
||||
NetworkAddresses: []NetworkAddress{{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.2/24"),
|
||||
Mac: "2",
|
||||
},
|
||||
{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Mac: "1",
|
||||
},
|
||||
},
|
||||
Files: []File{
|
||||
{
|
||||
Path: "/etc/hosts1",
|
||||
Exist: true,
|
||||
ProcessIsRunning: true,
|
||||
},
|
||||
{
|
||||
Path: "/etc/hosts2",
|
||||
Exist: false,
|
||||
ProcessIsRunning: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
meta2 := PeerSystemMeta{
|
||||
NetworkAddresses: []NetworkAddress{
|
||||
{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Mac: "1",
|
||||
},
|
||||
{
|
||||
NetIP: netip.MustParsePrefix("192.168.1.2/24"),
|
||||
Mac: "2",
|
||||
},
|
||||
},
|
||||
Files: []File{
|
||||
{
|
||||
Path: "/etc/hosts2",
|
||||
Exist: false,
|
||||
ProcessIsRunning: false,
|
||||
},
|
||||
{
|
||||
Path: "/etc/hosts1",
|
||||
Exist: true,
|
||||
ProcessIsRunning: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
if !meta1.isEqual(meta2) {
|
||||
t.Error("meta1 should be equal to meta2")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -282,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
var (
|
||||
group1 nbgroup.Group
|
||||
group2 nbgroup.Group
|
||||
policy Policy
|
||||
)
|
||||
|
||||
group1.ID = xid.New().String()
|
||||
group2.ID = xid.New().String()
|
||||
group1.Name = "src"
|
||||
group2.Name = "dst"
|
||||
policy.ID = xid.New().String()
|
||||
group1.Peers = append(group1.Peers, peer1.ID)
|
||||
group2.Peers = append(group2.Peers, peer2.ID)
|
||||
|
||||
@@ -304,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy.Name = "test"
|
||||
policy.Enabled = true
|
||||
policy.Rules = []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{group1.ID},
|
||||
Destinations: []string{group2.ID},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
policy := &Policy{
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{group1.ID},
|
||||
Destinations: []string{group2.ID},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
@@ -363,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
policy.Enabled = false
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
@@ -467,21 +468,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
Id: someUser,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
account.Settings.RegularUsersViewBlocked = false
|
||||
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: someUser,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleUser,
|
||||
})
|
||||
require.NoError(t, err, "failed to create user")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.RegularUsersViewBlocked = false
|
||||
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
|
||||
require.NoError(t, err, "failed to save account settings")
|
||||
|
||||
// two peers one added by a regular user and one with a setup key
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -535,7 +540,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
assert.NotNil(t, peer)
|
||||
|
||||
// delete the all-to-all policy so that user's peer1 has no access to peer2
|
||||
for _, policy := range account.Policies {
|
||||
accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account policies")
|
||||
|
||||
for _, policy := range accountPolicies {
|
||||
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -654,21 +662,33 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
|
||||
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: someUser,
|
||||
AccountID: accountID,
|
||||
Role: testCase.role,
|
||||
IsServiceUser: testCase.isServiceUser,
|
||||
}
|
||||
account.Policies = []*Policy{}
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
})
|
||||
require.NoError(t, err, "failed to create user")
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account policies")
|
||||
|
||||
for _, policy := range accountPolicies {
|
||||
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
|
||||
require.NoError(t, err, "failed to delete policy")
|
||||
}
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
|
||||
require.NoError(t, err, "failed to save account settings")
|
||||
|
||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -724,10 +744,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[regularUser] = &User{
|
||||
Id: regularUser,
|
||||
Role: UserRoleUser,
|
||||
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: regularUser,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleUser,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
// Create peers
|
||||
@@ -741,31 +769,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
UserID: regularUser,
|
||||
}
|
||||
account.Peers[peer.ID] = peer
|
||||
err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
account.Policies = make([]*Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &nbgroup.Group{
|
||||
ID: groupID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
ID: groupID,
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
}
|
||||
for j := 0; j < peers/groups; j++ {
|
||||
peerIndex := i*(peers/groups) + j
|
||||
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||
}
|
||||
account.Groups[groupID] = group
|
||||
|
||||
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &Policy{
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: fmt.Sprintf("rule-%d", i),
|
||||
PolicyID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||
Enabled: true,
|
||||
Sources: []string{groupID},
|
||||
@@ -776,22 +813,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
},
|
||||
},
|
||||
}
|
||||
account.Policies = append(account.Policies, policy)
|
||||
|
||||
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{
|
||||
ID: "PostureChecksAll",
|
||||
Name: "All",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{
|
||||
ID: "PostureChecksAll",
|
||||
AccountID: accountID,
|
||||
Name: "All",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
@@ -876,7 +914,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.updateAccountPeers(ctx, account)
|
||||
manager.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
@@ -1398,11 +1436,55 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validator requires update", func(t *testing.T) {
|
||||
requireUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
return update, true, nil
|
||||
}
|
||||
|
||||
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validator requires no update", func(t *testing.T) {
|
||||
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
@@ -1412,7 +1494,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -41,6 +41,7 @@ type PersonalAccessToken struct {
|
||||
func (t *PersonalAccessToken) Copy() *PersonalAccessToken {
|
||||
return &PersonalAccessToken{
|
||||
ID: t.ID,
|
||||
UserID: t.UserID,
|
||||
Name: t.Name,
|
||||
HashedToken: t.HashedToken,
|
||||
ExpirationDate: t.ExpirationDate,
|
||||
@@ -58,7 +59,7 @@ type PersonalAccessTokenGenerated struct {
|
||||
|
||||
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
||||
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
||||
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
hashedToken, plainToken, err := generateNewToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -67,6 +68,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
|
||||
return &PersonalAccessTokenGenerated{
|
||||
PersonalAccessToken: PersonalAccessToken{
|
||||
ID: xid.New().String(),
|
||||
UserID: targetID,
|
||||
Name: name,
|
||||
HashedToken: hashedToken,
|
||||
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),
|
||||
|
||||
@@ -3,13 +3,13 @@ package server
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -125,6 +125,7 @@ type PolicyRule struct {
|
||||
func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
rule := &PolicyRule{
|
||||
ID: pm.ID,
|
||||
PolicyID: pm.PolicyID,
|
||||
Name: pm.Name,
|
||||
Description: pm.Description,
|
||||
Enabled: pm.Enabled,
|
||||
@@ -171,6 +172,7 @@ type Policy struct {
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
@@ -343,44 +345,72 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var action = activity.PolicyAdded
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
saveFunc := transaction.CreatePolicy
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
|
||||
return saveFunc(ctx, LockingStrengthUpdate, policy)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action := activity.PolicyAdded
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
}
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// DeletePolicy from the store
|
||||
@@ -388,110 +418,136 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy, err := am.deletePolicy(account, policyID)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var policy *Policy
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPolicies from the store
|
||||
// ListPolicies from the store.
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||
policyIdx := -1
|
||||
for i, policy := range account.Policies {
|
||||
if policy.ID == policyID {
|
||||
policyIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if policyIdx < 0 {
|
||||
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
|
||||
}
|
||||
|
||||
policy := account.Policies[policyIdx]
|
||||
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// savePolicy saves or updates a policy in the given account.
|
||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
||||
for index, rule := range policyToSave.Rules {
|
||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||
policyToSave.Rules[index] = rule
|
||||
}
|
||||
|
||||
if policyToSave.SourcePostureChecks != nil {
|
||||
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||
if policyIdx < 0 {
|
||||
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldPolicy := account.Policies[policyIdx]
|
||||
// Update the existing policy
|
||||
account.Policies[policyIdx] = policyToSave
|
||||
|
||||
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||
|
||||
return updateAccountPeers, nil
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
|
||||
}
|
||||
|
||||
// Add the new policy to the account
|
||||
account.Policies = append(account.Policies, policyToSave)
|
||||
|
||||
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
|
||||
}
|
||||
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
result[i] = &proto.FirewallRule{
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
|
||||
if policy.ID != "" {
|
||||
_, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
policy.ID = xid.New().String()
|
||||
policy.AccountID = accountID
|
||||
}
|
||||
return result
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, rule := range policy.Rules {
|
||||
ruleCopy := rule.Copy()
|
||||
if ruleCopy.ID == "" {
|
||||
ruleCopy.ID = xid.New().String()
|
||||
ruleCopy.PolicyID = policy.ID
|
||||
}
|
||||
|
||||
ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
|
||||
ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
|
||||
policy.Rules[i] = ruleCopy
|
||||
}
|
||||
|
||||
if policy.SourcePostureChecks != nil {
|
||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
@@ -572,27 +628,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
||||
// that are valid within the provided account.
|
||||
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
||||
result := make([]string, 0, len(postureChecksIds))
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
|
||||
validIDs := make([]string, 0, len(postureChecksIds))
|
||||
for _, id := range postureChecksIds {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if id == postureCheck.ID {
|
||||
result = append(result, id)
|
||||
continue
|
||||
}
|
||||
if _, exists := postureChecks[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
return validIDs
|
||||
}
|
||||
|
||||
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
||||
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
||||
result := make([]string, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if _, exists := account.Groups[groupID]; exists {
|
||||
result = append(result, groupID)
|
||||
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
||||
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
|
||||
validIDs := make([]string, 0, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if _, exists := groups[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
return validIDs
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
result[i] = &proto.FirewallRule{
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
@@ -854,24 +853,28 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
var policyWithGroupRulesNoPeers *Policy
|
||||
var policyWithDestinationPeersOnly *Policy
|
||||
var policyWithSourceAndDestinationPeers *Policy
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-rule-groups-no-peers",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
@@ -879,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -900,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Saving policy with source group containing peers, but destination group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-has-peers-destination-none",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
@@ -914,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -935,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Saving policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-destination-has-peers-source-none",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: false,
|
||||
Enabled: true,
|
||||
Sources: []string{"groupC"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
@@ -949,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -970,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Saving policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
@@ -983,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1004,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Disabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
policyWithSourceAndDestinationPeers.Enabled = false
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1038,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
||||
// or send peer update
|
||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Description: "updated description",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
policyWithSourceAndDestinationPeers.Description = "updated description"
|
||||
policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1073,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Enabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
policyWithSourceAndDestinationPeers.Enabled = true
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1104,50 +1048,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged policy should trigger account peers update but not send peer update
|
||||
t.Run("saving unchanged policy", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting policy should trigger account peers update and send peer update
|
||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policyID := "policy-source-destination-peers"
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1161,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Deleting policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policyID := "policy-destination-has-peers-source-none"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1180,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
|
||||
}
|
||||
|
||||
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
|
||||
if postureChecksID == "" {
|
||||
postureChecksID = xid.New().String()
|
||||
}
|
||||
|
||||
postureChecks := Checks{
|
||||
ID: postureChecksID,
|
||||
Name: name,
|
||||
|
||||
@@ -2,16 +2,15 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
@@ -20,219 +19,279 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
||||
|
||||
// we do not allow create new posture checks with non uniq name
|
||||
if !exists && !uniqName {
|
||||
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
action := activity.PostureCheckCreated
|
||||
if exists {
|
||||
action = activity.PostureCheckUpdated
|
||||
account.Network.IncSerial()
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
var updateAccountPeers bool
|
||||
var isUpdate = postureChecks.ID != ""
|
||||
var action = activity.PostureCheckCreated
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
// DeletePostureChecks deletes a posture check by ID.
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
|
||||
var postureChecks *posture.Checks
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPostureChecks returns a list of posture checks.
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||
uniqName = true
|
||||
for i, p := range account.PostureChecks {
|
||||
if !exists && p.ID == postureChecks.ID {
|
||||
account.PostureChecks[i] = postureChecks
|
||||
exists = true
|
||||
}
|
||||
if p.Name == postureChecks.Name {
|
||||
uniqName = false
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
account.PostureChecks = append(account.PostureChecks, postureChecks)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
|
||||
postureChecksIdx := -1
|
||||
for i, postureChecks := range account.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
postureChecksIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if postureChecksIdx < 0 {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
||||
}
|
||||
|
||||
// Check if posture check is linked to any policy
|
||||
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
||||
}
|
||||
|
||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
||||
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
|
||||
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
|
||||
peerPostureChecks := make(map[string]posture.Checks)
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
if len(account.PostureChecks) == 0 {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(postureChecks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return maps.Values(peerPostureChecks), nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// validatePostureChecks validates the posture checks.
|
||||
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
// If the posture check already has an ID, verify its existence in the store.
|
||||
if postureChecks.ID != "" {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
// For new posture checks, ensure no duplicates by name.
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
|
||||
addPolicyPostureChecks(account, policy, peerPostureChecks)
|
||||
for _, check := range checks {
|
||||
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
|
||||
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
|
||||
}
|
||||
}
|
||||
|
||||
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
|
||||
for _, check := range peerPostureChecks {
|
||||
checkCopy := check
|
||||
postureChecksList = append(postureChecksList, &checkCopy)
|
||||
postureChecks.ID = xid.New().String()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||
func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return postureChecksList
|
||||
if !isInGroup {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
|
||||
func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group, ok := account.Groups[sourceGroup]
|
||||
if ok && slices.Contains(group.Peers, peerID) {
|
||||
return true
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if postureCheck.ID == sourcePostureCheckID {
|
||||
peerPostureChecks[sourcePostureCheckID] = *postureCheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
||||
for _, policy := range account.Policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return true, policy
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
||||
if !exists {
|
||||
return false
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
||||
if !isLinked {
|
||||
return false
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||
}
|
||||
}
|
||||
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
)
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
const (
|
||||
adminUserID = "adminUserID"
|
||||
regularUserID = "regularUserID"
|
||||
postureCheckID = "existing-id"
|
||||
postureCheckName = "Existing check"
|
||||
)
|
||||
|
||||
@@ -25,23 +25,22 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
accountID, err := initTestPostureChecksAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
t.Run("Generic posture check flow", func(t *testing.T) {
|
||||
// regular users can not create checks
|
||||
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
||||
_, err = am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{})
|
||||
assert.Error(t, err)
|
||||
|
||||
// regular users cannot list check
|
||||
_, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID)
|
||||
_, err = am.ListPostureChecks(context.Background(), accountID, regularUserID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// should be possible to create posture check with uniq name
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
postureCheck, err := am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
@@ -52,13 +51,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// admin users can list check
|
||||
checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
||||
checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, checks, 1)
|
||||
|
||||
// should not be possible to create posture check with non uniq name
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: "new-id",
|
||||
_, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||
@@ -73,53 +71,53 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
|
||||
// admins can update posture checks
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.27.0",
|
||||
},
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.27.0",
|
||||
},
|
||||
})
|
||||
}
|
||||
_, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// users should not be able to delete posture checks
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
|
||||
err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, regularUserID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// admin should be able to delete posture checks
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
|
||||
err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
||||
checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, checks, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) {
|
||||
accountID := "testingAccount"
|
||||
domain := "example.com"
|
||||
|
||||
admin := &User{
|
||||
Id: adminUserID,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
user := &User{
|
||||
Id: regularUserID,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
|
||||
{
|
||||
Id: adminUserID,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleAdmin,
|
||||
},
|
||||
{
|
||||
Id: regularUserID,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleUser,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
@@ -149,9 +147,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
postureCheck := posture.Checks{
|
||||
ID: "postureCheck",
|
||||
Name: "postureCheck",
|
||||
postureCheckA := &posture.Checks{
|
||||
Name: "postureCheckA",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
|
||||
require.NoError(t, err)
|
||||
|
||||
postureCheckB := &posture.Checks{
|
||||
Name: "postureCheckB",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
@@ -168,7 +179,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -186,12 +197,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -201,12 +212,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
policy := Policy{
|
||||
ID: "policyA",
|
||||
policy := &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -214,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
}
|
||||
|
||||
// Linking posture check to policy should trigger update account peers and send peer update
|
||||
@@ -225,7 +234,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -237,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Updating linked posture checks should update account peers and send peer update
|
||||
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
@@ -254,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -264,25 +273,6 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged posture check should not trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing posture check from policy should trigger account peers update and send peer update
|
||||
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
@@ -292,8 +282,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
policy.SourcePostureChecks = []string{}
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -311,7 +300,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
|
||||
err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -321,17 +310,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
@@ -339,9 +326,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -350,12 +336,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -372,12 +358,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -385,10 +370,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -397,12 +380,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -412,52 +395,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture check to policy where source has peers but destination does not,
|
||||
// should not trigger account peers update or send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -468,9 +409,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -479,7 +419,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
@@ -488,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -499,80 +439,120 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
||||
account := &Account{
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policyA",
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{"checkA"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*group.Group{
|
||||
"groupA": {
|
||||
ID: "groupA",
|
||||
Peers: []string{"peer1"},
|
||||
},
|
||||
"groupB": {
|
||||
ID: "groupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: "checkA",
|
||||
},
|
||||
{
|
||||
ID: "checkB",
|
||||
},
|
||||
},
|
||||
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
accountID, err := initTestPostureChecksAccount(manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
groupA := &group.Group{
|
||||
ID: "groupA",
|
||||
AccountID: accountID,
|
||||
Peers: []string{"peer1"},
|
||||
}
|
||||
|
||||
groupB := &group.Group{
|
||||
ID: "groupB",
|
||||
AccountID: accountID,
|
||||
Peers: []string{},
|
||||
}
|
||||
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
postureCheckA := &posture.Checks{
|
||||
Name: "checkA",
|
||||
AccountID: accountID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||
},
|
||||
}
|
||||
postureCheckA, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckA)
|
||||
require.NoError(t, err, "failed to save postureCheckA")
|
||||
|
||||
postureCheckB := &posture.Checks{
|
||||
Name: "checkB",
|
||||
AccountID: accountID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||
},
|
||||
}
|
||||
postureCheckB, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckB)
|
||||
require.NoError(t, err, "failed to save postureCheckB")
|
||||
|
||||
policy := &Policy{
|
||||
AccountID: accountID,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheckA.ID},
|
||||
}
|
||||
|
||||
policy, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to save policy")
|
||||
|
||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckB.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check does not exist", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, "unknown")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
policy.Rules[0].Sources = []string{"groupB"}
|
||||
policy.Rules[0].Destinations = []string{"groupA"}
|
||||
_, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
policy.Rules[0].Sources = []string{"groupA"}
|
||||
policy.Rules[0].Destinations = []string{"groupB"}
|
||||
_, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
groupA.Peers = []string{}
|
||||
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
account.Groups["groupA"].Peers = []string{}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
policy.Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
_, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -52,17 +53,46 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func getRoutesByPrefixOrDomains(ctx context.Context, transaction Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||
accountRoutes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, r := range accountRoutes {
|
||||
dynamic := r.IsDynamic()
|
||||
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||
!dynamic && r.Network.String() == prefix.String() {
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction Store, accountID string, checkRoute *route.Route, groupsMap map[string]*nbgroup.Group) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
prefix := checkRoute.Network
|
||||
domains := checkRoute.Domains
|
||||
|
||||
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||
seenPeers := make(map[string]bool)
|
||||
@@ -71,18 +101,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
for _, prefixRoute := range routesWithPrefix {
|
||||
// we skip route(s) with the same network ID as we want to allow updating of the existing route
|
||||
// when creating a new route routeID is newly generated so nothing will be skipped
|
||||
if routeID == prefixRoute.ID {
|
||||
if checkRoute.ID == prefixRoute.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefixRoute.Peer != "" {
|
||||
seenPeers[string(prefixRoute.ID)] = true
|
||||
}
|
||||
|
||||
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, prefixRoute.PeerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, groupID := range prefixRoute.PeerGroups {
|
||||
seenPeerGroups[groupID] = true
|
||||
|
||||
group := account.GetGroup(groupID)
|
||||
if group == nil {
|
||||
group, ok := peerGroupsMap[groupID]
|
||||
if !ok || group == nil {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||
getRouteDescriptor(prefix, domains), groupID,
|
||||
@@ -95,12 +131,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
}
|
||||
|
||||
if peerID != "" {
|
||||
if peerID := checkRoute.Peer; peerID != "" {
|
||||
// check that peerID exists and is not in any route as single peer or part of the group
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
_, err = transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeers[peerID]; ok {
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||
@@ -108,9 +145,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
|
||||
// check that peerGroupIDs are not in any route peerGroups list
|
||||
for _, groupID := range peerGroupIDs {
|
||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
||||
|
||||
for _, groupID := range checkRoute.PeerGroups {
|
||||
group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
|
||||
if _, ok := seenPeerGroups[groupID]; ok {
|
||||
return status.Errorf(
|
||||
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
|
||||
@@ -118,12 +154,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
|
||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||
peersMap, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, group.Peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range group.Peers {
|
||||
if _, ok := seenPeers[id]; ok {
|
||||
peer := account.GetPeer(id)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
peer, ok := peersMap[id]
|
||||
if !ok || peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
|
||||
}
|
||||
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||
@@ -146,104 +188,63 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerID); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var newRoute *route.Route
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
newRoute = &route.Route{
|
||||
ID: route.ID(xid.New().String()),
|
||||
AccountID: accountID,
|
||||
Network: prefix,
|
||||
Domains: domains,
|
||||
KeepRoute: keepRoute,
|
||||
NetID: netID,
|
||||
Description: description,
|
||||
Peer: peerID,
|
||||
PeerGroups: peerGroupIDs,
|
||||
NetworkType: networkType,
|
||||
Masquerade: masquerade,
|
||||
Metric: metric,
|
||||
Enabled: enabled,
|
||||
Groups: groups,
|
||||
AccessControlGroups: accessControlGroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
if len(domains) > 0 && prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(domains) == 0 && !prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
|
||||
}
|
||||
|
||||
if len(domains) > 0 {
|
||||
prefix = getPlaceholderIP()
|
||||
}
|
||||
|
||||
if peerID != "" && len(peerGroupIDs) != 0 {
|
||||
return nil, status.Errorf(
|
||||
status.InvalidArgument,
|
||||
"peer with ID %s and peers group %s should not be provided at the same time",
|
||||
peerID, peerGroupIDs)
|
||||
}
|
||||
|
||||
var newRoute route.Route
|
||||
newRoute.ID = route.ID(xid.New().String())
|
||||
|
||||
if len(peerGroupIDs) > 0 {
|
||||
err = validateGroups(peerGroupIDs, account.Groups)
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessControlGroupIDs) > 0 {
|
||||
err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
return transaction.SaveRoute(ctx, LockingStrengthUpdate, newRoute)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
err = validateGroups(groups, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRoute.Peer = peerID
|
||||
newRoute.PeerGroups = peerGroupIDs
|
||||
newRoute.Network = prefix
|
||||
newRoute.Domains = domains
|
||||
newRoute.NetworkType = networkType
|
||||
newRoute.Description = description
|
||||
newRoute.NetID = netID
|
||||
newRoute.Masquerade = masquerade
|
||||
newRoute.Metric = metric
|
||||
newRoute.Enabled = enabled
|
||||
newRoute.Groups = groups
|
||||
newRoute.KeepRoute = keepRoute
|
||||
newRoute.AccessControlGroups = accessControlGroupIDs
|
||||
|
||||
if account.Routes == nil {
|
||||
account.Routes = make(map[route.ID]*route.Route)
|
||||
}
|
||||
|
||||
account.Routes[newRoute.ID] = &newRoute
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
return &newRoute, nil
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return newRoute, nil
|
||||
}
|
||||
|
||||
// SaveRoute saves route
|
||||
@@ -251,10 +252,151 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var oldRoute *route.Route
|
||||
var oldRouteAffectsPeers bool
|
||||
var newRouteAffectsPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute, err = transaction.GetRouteByID(ctx, LockingStrengthUpdate, accountID, string(routeToSave.ID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveRoute(ctx, LockingStrengthUpdate, routeToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var route *route.Route
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
route, err = transaction.GetRouteByID(ctx, LockingStrengthUpdate, accountID, string(routeID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteRoute(ctx, LockingStrengthUpdate, accountID, string(routeID))
|
||||
})
|
||||
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateRoute(ctx context.Context, transaction Store, accountID string, routeToSave *route.Route) error {
|
||||
if routeToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||
}
|
||||
|
||||
if err := validateRouteProperties(routeToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if routeToSave.Peer != "" {
|
||||
peer, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
|
||||
}
|
||||
|
||||
// Helper to validate route properties.
|
||||
func validateRouteProperties(routeToSave *route.Route) error {
|
||||
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
|
||||
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
@@ -263,18 +405,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
@@ -291,89 +421,34 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
||||
}
|
||||
|
||||
if len(routeToSave.PeerGroups) > 0 {
|
||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeToSave.AccessControlGroups) > 0 {
|
||||
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute := account.Routes[routeToSave.ID]
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routy := account.Routes[routeID]
|
||||
if routy == nil {
|
||||
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
|
||||
}
|
||||
delete(account.Routes, routeID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
if isRouteChangeAffectPeers(account, routy) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
// validateRouteGroups validates the route groups and returns the validated groups map.
|
||||
func validateRouteGroups(ctx context.Context, transaction Store, accountID string, routeToSave *route.Route) (map[string]*nbgroup.Group, error) {
|
||||
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
|
||||
groupsMap, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupsToValidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
if len(routeToSave.PeerGroups) > 0 {
|
||||
if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if len(routeToSave.AccessControlGroups) > 0 {
|
||||
if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
@@ -649,8 +724,21 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||
func areRouteChangesAffectPeers(ctx context.Context, transaction Store, route *route.Route) (bool, error) {
|
||||
if route.Peer != "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, route.AccountID, route.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, route.AccountID, route.PeerGroups)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -427,22 +429,22 @@ func TestCreateRoute(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Errorf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
if testCase.createInitRoute {
|
||||
groupAll, errInit := account.GetGroupAll()
|
||||
groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
|
||||
|
||||
_, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
|
||||
_, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
}
|
||||
|
||||
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
||||
|
||||
outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
if !testCase.shouldCreate {
|
||||
@@ -451,6 +453,7 @@ func TestCreateRoute(t *testing.T) {
|
||||
|
||||
// assign generated ID
|
||||
testCase.expectedRoute.ID = outRoute.ID
|
||||
testCase.expectedRoute.AccountID = accountID
|
||||
|
||||
if !testCase.expectedRoute.IsEqual(outRoute) {
|
||||
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute)
|
||||
@@ -917,14 +920,15 @@ func TestSaveRoute(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
if testCase.createInitRoute {
|
||||
account.Routes["initRoute"] = &route.Route{
|
||||
initRoute := &route.Route{
|
||||
ID: "initRoute",
|
||||
AccountID: accountID,
|
||||
Network: existingNetwork,
|
||||
NetID: existingRouteID,
|
||||
NetworkType: route.IPv4Network,
|
||||
@@ -935,14 +939,13 @@ func TestSaveRoute(t *testing.T) {
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, initRoute)
|
||||
require.NoError(t, err, "failed to save init route")
|
||||
}
|
||||
|
||||
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("account should be saved")
|
||||
}
|
||||
testCase.existingRoute.AccountID = accountID
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, testCase.existingRoute)
|
||||
require.NoError(t, err, "failed to save existing route")
|
||||
|
||||
var routeToSave *route.Route
|
||||
|
||||
@@ -977,7 +980,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave)
|
||||
err = am.SaveRoute(context.Background(), accountID, userID, routeToSave)
|
||||
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
@@ -985,14 +988,10 @@ func TestSaveRoute(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
|
||||
require.True(t, saved)
|
||||
savedRoute, err := am.GetRoute(context.Background(), accountID, testCase.existingRoute.ID, userID)
|
||||
require.NoError(t, err, "failed to get saved route")
|
||||
|
||||
testCase.expectedRoute.AccountID = accountID
|
||||
if !testCase.expectedRoute.IsEqual(savedRoute) {
|
||||
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute)
|
||||
}
|
||||
@@ -1001,50 +1000,48 @@ func TestSaveRoute(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteRoute(t *testing.T) {
|
||||
testingRoute := &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
Domains: domain.List{"domain1", "domain2"},
|
||||
KeepRoute: true,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1Key,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
am, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.Routes[testingRoute.ID] = testingRoute
|
||||
err = am.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "GroupA",
|
||||
AccountID: accountID,
|
||||
Name: "GroupA",
|
||||
})
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save account")
|
||||
testingRoute := &route.Route{
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: route.NetID("12345678901234567890qw"),
|
||||
Groups: []string{"GroupA"},
|
||||
KeepRoute: true,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
createdRoute, err := am.CreateRoute(context.Background(), accountID, testingRoute.Network, testingRoute.NetworkType, testingRoute.Domains, peer1ID, []string{}, testingRoute.Description, testingRoute.NetID, testingRoute.Masquerade, testingRoute.Metric, testingRoute.Groups, testingRoute.AccessControlGroups, true, userID, testingRoute.KeepRoute)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
|
||||
err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID)
|
||||
err = am.DeleteRoute(context.Background(), accountID, createdRoute.ID, userID)
|
||||
if err != nil {
|
||||
t.Error("deleting route failed with error: ", err)
|
||||
}
|
||||
|
||||
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Error("failed to retrieve saved account with error: ", err)
|
||||
}
|
||||
|
||||
_, found := savedAccount.Routes[testingRoute.ID]
|
||||
if found {
|
||||
t.Error("route shouldn't be found after delete")
|
||||
}
|
||||
_, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID)
|
||||
require.NotNil(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, sErr.Type())
|
||||
}
|
||||
|
||||
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
@@ -1066,16 +1063,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
||||
newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newRoute.Enabled, true)
|
||||
|
||||
@@ -1091,7 +1086,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||
|
||||
groups, err := am.ListGroups(context.Background(), account.Id)
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
var groupHA1, groupHA2 *nbgroup.Group
|
||||
for _, group := range groups {
|
||||
@@ -1103,21 +1098,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID)
|
||||
err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
|
||||
|
||||
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
|
||||
err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
|
||||
|
||||
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1128,7 +1123,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
|
||||
|
||||
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID)
|
||||
err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1158,7 +1153,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
@@ -1167,7 +1162,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
|
||||
createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1181,7 +1176,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
expectedRoute := enabledRoute.Copy()
|
||||
expectedRoute.Peer = peer1Key
|
||||
|
||||
err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute)
|
||||
err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1193,7 +1188,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group")
|
||||
|
||||
err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID)
|
||||
err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID)
|
||||
@@ -1206,23 +1201,22 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
Name: "peer1 group",
|
||||
Peers: []string{peer1ID},
|
||||
}
|
||||
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup)
|
||||
err = am.SaveGroup(context.Background(), accountID, userID, newGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser")
|
||||
rules, err := am.ListPolicies(context.Background(), accountID, "testingUser")
|
||||
require.NoError(t, err)
|
||||
|
||||
defaultRule := rules[0]
|
||||
newPolicy := defaultRule.Copy()
|
||||
newPolicy.ID = xid.New().String()
|
||||
newPolicy.Name = "peer1 only"
|
||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||
|
||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
||||
_, err = am.SavePolicy(context.Background(), accountID, userID, newPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||
err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1233,7 +1227,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2")
|
||||
|
||||
err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID)
|
||||
err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1267,179 +1261,104 @@ func createRouterStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
ips := account.getTakenIPs()
|
||||
peer1IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) {
|
||||
ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerIP, err := AllocatePeerIP(network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
IP: peerIP,
|
||||
AccountID: accountID,
|
||||
ID: peerID,
|
||||
Key: peerKey,
|
||||
Name: peerName,
|
||||
DNSLabel: dnsLabel,
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: peerName,
|
||||
GoOS: strings.ToLower(kernel),
|
||||
Kernel: kernel,
|
||||
Core: core,
|
||||
Platform: platform,
|
||||
OS: os,
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
if err := am.Store.AddPeerToAccount(context.Background(), peer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// Create peers
|
||||
peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
IP: peer1IP,
|
||||
ID: peer1ID,
|
||||
Key: peer1Key,
|
||||
Name: "test-host1@netbird.io",
|
||||
DNSLabel: "test-host1",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host1@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer1.ID] = peer1
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
IP: peer2IP,
|
||||
ID: peer2ID,
|
||||
Key: peer2Key,
|
||||
Name: "test-host2@netbird.io",
|
||||
DNSLabel: "test-host2",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host2@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer2.ID] = peer2
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer3 := &nbpeer.Peer{
|
||||
IP: peer3IP,
|
||||
ID: peer3ID,
|
||||
Key: peer3Key,
|
||||
Name: "test-host3@netbird.io",
|
||||
DNSLabel: "test-host3",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host3@netbird.io",
|
||||
GoOS: "darwin",
|
||||
Kernel: "Darwin",
|
||||
Core: "13.4.1",
|
||||
Platform: "arm64",
|
||||
OS: "darwin",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer3.ID] = peer3
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer4 := &nbpeer.Peer{
|
||||
IP: peer4IP,
|
||||
ID: peer4ID,
|
||||
Key: peer4Key,
|
||||
Name: "test-host4@netbird.io",
|
||||
DNSLabel: "test-host4",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host4@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer4.ID] = peer4
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer5 := &nbpeer.Peer{
|
||||
IP: peer5IP,
|
||||
ID: peer5ID,
|
||||
Key: peer5Key,
|
||||
Name: "test-host5@netbird.io",
|
||||
DNSLabel: "test-host5",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host5@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
groupAll, err := am.GetGroupByName(context.Background(), "All", accountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
account.Peers[peer5.ID] = peer5
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupAll, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
newGroup := []*nbgroup.Group{
|
||||
newGroups := []*nbgroup.Group{
|
||||
{
|
||||
ID: routeGroup1,
|
||||
Name: routeGroup1,
|
||||
@@ -1471,15 +1390,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
||||
Peers: []string{peer1.ID, peer4.ID},
|
||||
},
|
||||
}
|
||||
|
||||
for _, group := range newGroup {
|
||||
err = am.SaveGroup(context.Background(), accountID, userID, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.SaveGroups(context.Background(), accountID, userID, newGroups)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
@@ -1783,10 +1699,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
manager, err := createRouterManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
account, err := initTestRouteAccount(t, manager)
|
||||
accountID, err := initTestRouteAccount(t, manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -1832,7 +1748,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
@@ -1868,7 +1784,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
@@ -1904,7 +1820,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
newRoute, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||
context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
|
||||
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
|
||||
)
|
||||
@@ -1928,7 +1844,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1938,26 +1854,6 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Updating unchanged route should update account peers and not send peer update
|
||||
t.Run("updating unchanged route", func(t *testing.T) {
|
||||
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting the route should update account peers and send peer update
|
||||
t.Run("deleting route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
@@ -1966,7 +1862,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
|
||||
err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1990,7 +1886,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
@@ -2002,7 +1898,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1ID},
|
||||
@@ -2030,7 +1926,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
Groups: []string{"groupC"},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
@@ -2042,7 +1938,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1ID},
|
||||
|
||||
@@ -4,18 +4,17 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -229,32 +228,43 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
|
||||
return nil, err
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
var plainKey string
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
setupKey.AccountID = accountID
|
||||
|
||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed adding account key")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
|
||||
|
||||
for _, g := range setupKey.AutoGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
}
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
// for the creation return the plain key to the caller
|
||||
@@ -268,43 +278,56 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var oldKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Id == keyToSave.Id {
|
||||
oldKey = key.Copy()
|
||||
break
|
||||
var newKey *SetupKey
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if oldKey == nil {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newKey := oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newKey = oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
|
||||
account.SetupKeys[newKey.Key] = newKey
|
||||
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
|
||||
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -312,30 +335,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta())
|
||||
}
|
||||
|
||||
defer func() {
|
||||
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
|
||||
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
return newKey, nil
|
||||
}
|
||||
@@ -347,16 +349,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return setupKeys, nil
|
||||
return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
@@ -366,8 +367,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
@@ -387,21 +392,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var deletedSetupKey *SetupKey
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete setup key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
|
||||
@@ -409,15 +422,62 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||
for _, group := range autoGroups {
|
||||
g, ok := account.Groups[group]
|
||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, groupID := range autoGroupIDs {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
return status.Errorf(status.NotFound, "group not found: %s", groupID)
|
||||
}
|
||||
if g.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
|
||||
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareSetupKeyEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range removedGroups {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
|
||||
})
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
@@ -25,12 +25,12 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
@@ -49,7 +49,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
expiresIn := time.Hour
|
||||
keyName := "my-test-key"
|
||||
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{},
|
||||
key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{},
|
||||
SetupKeyUnlimitedUsage, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -58,7 +58,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
autoGroups := []string{"group_1", "group_2"}
|
||||
newKeyName := "my-new-test-key"
|
||||
revoked := true
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
@@ -72,22 +72,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
key.Id, time.Now().UTC(), autoGroups, true)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked)
|
||||
ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked)
|
||||
|
||||
assert.NotNil(t, ev)
|
||||
assert.Equal(t, account.Id, ev.AccountID)
|
||||
assert.Equal(t, accountID, ev.AccountID)
|
||||
assert.Equal(t, newKeyName, ev.Meta["name"])
|
||||
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
assert.Equal(t, key.Id, ev.TargetID)
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// saving setup key with All group assigned to auto groups should return error
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
_, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
@@ -103,12 +103,12 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
@@ -117,7 +117,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
@@ -126,7 +126,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
@@ -170,7 +170,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
|
||||
|
||||
if tCase.expectedFailure {
|
||||
@@ -189,10 +189,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
tCase.expectedUpdatedAt, tCase.expectedGroups, false)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated)
|
||||
ev := getEvent(t, accountID, manager, activity.SetupKeyCreated)
|
||||
|
||||
assert.NotNil(t, ev)
|
||||
assert.Equal(t, account.Id, ev.AccountID)
|
||||
assert.Equal(t, accountID, ev.AccountID)
|
||||
assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"])
|
||||
assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"]))
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
@@ -208,12 +208,12 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
@@ -222,7 +222,7 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
@@ -390,8 +390,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
policy := &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -403,7 +402,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -86,9 +86,14 @@ func NewAccountNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||
}
|
||||
|
||||
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
|
||||
func NewPeerNotPartOfAccountError() error {
|
||||
return Errorf(PermissionDenied, "peer is not part of this account")
|
||||
}
|
||||
|
||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||
func NewUserNotFoundError(userKey string) error {
|
||||
return Errorf(NotFound, "user not found: %s", userKey)
|
||||
return Errorf(NotFound, "user: %s not found", userKey)
|
||||
}
|
||||
|
||||
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
|
||||
@@ -102,25 +107,74 @@ func NewPeerLoginExpiredError() error {
|
||||
}
|
||||
|
||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||
func NewSetupKeyNotFoundError(err error) error {
|
||||
return Errorf(NotFound, "setup key not found: %s", err)
|
||||
func NewSetupKeyNotFoundError(setupKeyID string) error {
|
||||
return Errorf(NotFound, "setup key: %s not found", setupKeyID)
|
||||
}
|
||||
|
||||
func NewGetAccountFromStoreError(err error) error {
|
||||
return Errorf(Internal, "issue getting account from store: %s", err)
|
||||
}
|
||||
|
||||
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
|
||||
func NewUserNotPartOfAccountError() error {
|
||||
return Errorf(PermissionDenied, "user is not part of this account")
|
||||
}
|
||||
|
||||
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
|
||||
func NewGetUserFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
|
||||
func NewAdminPermissionError() error {
|
||||
return Errorf(PermissionDenied, "admin role required to perform this action")
|
||||
}
|
||||
|
||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||
func NewInvalidKeyIDError() error {
|
||||
return Errorf(InvalidArgument, "invalid key ID")
|
||||
}
|
||||
|
||||
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key
|
||||
func NewUnauthorizedToViewSetupKeysError() error {
|
||||
return Errorf(Unauthorized, "only users with admin power can view setup keys")
|
||||
// NewGetAccountError creates a new Error with Internal type for an issue getting account
|
||||
func NewGetAccountError(err error) error {
|
||||
return Errorf(Internal, "error getting account: %s", err)
|
||||
}
|
||||
|
||||
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
|
||||
func NewGroupNotFoundError(groupID string) error {
|
||||
return Errorf(NotFound, "group: %s not found", groupID)
|
||||
}
|
||||
|
||||
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
|
||||
func NewPostureChecksNotFoundError(postureChecksID string) error {
|
||||
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
|
||||
}
|
||||
|
||||
// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy
|
||||
func NewPolicyNotFoundError(policyID string) error {
|
||||
return Errorf(NotFound, "policy: %s not found", policyID)
|
||||
}
|
||||
|
||||
// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group
|
||||
func NewNameServerGroupNotFoundError(nsGroupID string) error {
|
||||
return Errorf(NotFound, "nameserver group: %s not found", nsGroupID)
|
||||
}
|
||||
|
||||
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
|
||||
func NewServiceUserRoleInvalidError() error {
|
||||
return Errorf(InvalidArgument, "can't create a service user with owner role")
|
||||
}
|
||||
|
||||
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
|
||||
// to delete a user with the owner role.
|
||||
func NewOwnerDeletePermissionError() error {
|
||||
return Errorf(PermissionDenied, "can't delete a user with the owner role")
|
||||
}
|
||||
|
||||
func NewPATNotFoundError(patID string) error {
|
||||
return Errorf(NotFound, "PAT: %s not found", patID)
|
||||
}
|
||||
|
||||
func NewRouteNotFoundError(routeID string) error {
|
||||
return Errorf(NotFound, "route: %s not found", routeID)
|
||||
}
|
||||
|
||||
@@ -48,64 +48,98 @@ type Store interface {
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByUserID(userID string) (string, error)
|
||||
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
||||
GetTotalAccounts(ctx context.Context) (int64, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
|
||||
CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
|
||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
|
||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
|
||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
|
||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
|
||||
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
||||
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
|
||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
||||
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
||||
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
@@ -124,7 +158,6 @@ type Store interface {
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
|
||||
@@ -34,4 +34,8 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
|
||||
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
||||
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
|
||||
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
|
||||
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
2
management/server/testdata/store.sql
vendored
2
management/server/testdata/store.sql
vendored
@@ -25,7 +25,7 @@ CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
|
||||
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`inactivity_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
|
||||
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
@@ -27,9 +27,13 @@ CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
||||
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,'');
|
||||
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
|
||||
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
2
management/server/testdata/storev1.sql
vendored
2
management/server/testdata/storev1.sql
vendored
@@ -34,6 +34,6 @@ INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038'
|
||||
INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',1,'""','','',0);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
|
||||
@@ -2,13 +2,9 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/differs"
|
||||
"github.com/r3labs/diff/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -25,8 +21,6 @@ type UpdateMessage struct {
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
|
||||
peerUpdateMessage map[string]*UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
@@ -36,10 +30,9 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerUpdateMessage: make(map[string]*UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,15 +41,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
// skip sending sync update to the peer if there is no change in update message,
|
||||
// it will not check on turn credential refresh as we do not send network map or client posture checks
|
||||
if update.NetworkMap != nil {
|
||||
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
|
||||
if !updated {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.channelsMux.Lock()
|
||||
|
||||
defer func() {
|
||||
@@ -66,16 +50,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
if update.NetworkMap != nil {
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
|
||||
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
|
||||
return
|
||||
}
|
||||
p.peerUpdateMessage[peerID] = update
|
||||
}
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
@@ -108,7 +82,6 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
@@ -123,10 +96,12 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
|
||||
}
|
||||
|
||||
// CloseChannels closes updates channel for each given peer
|
||||
@@ -200,72 +175,3 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
|
||||
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
|
||||
p.channelsMux.RLock()
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
p.channelsMux.RUnlock()
|
||||
|
||||
if lastSentUpdate != nil {
|
||||
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
|
||||
return false
|
||||
}
|
||||
if !updated {
|
||||
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
|
||||
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
|
||||
isNew, err = true, nil
|
||||
}
|
||||
}()
|
||||
|
||||
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
differ, err := diff.NewDiffer(
|
||||
diff.CustomValueDiffers(&differs.NetIPAddr{}),
|
||||
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create differ: %v", err)
|
||||
}
|
||||
|
||||
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
|
||||
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
|
||||
|
||||
changelog, err := differ.Diff(lastSentFiles, currFiles)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff checks: %v", err)
|
||||
}
|
||||
if len(changelog) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff network map: %v", err)
|
||||
}
|
||||
return len(changelog) > 0, nil
|
||||
}
|
||||
|
||||
// getChecksFiles returns a list of files from the given checks.
|
||||
func getChecksFiles(checks []*proto.Checks) []string {
|
||||
files := make([]string, 0, len(checks))
|
||||
for _, check := range checks {
|
||||
files = append(files, check.GetFiles()...)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
@@ -2,19 +2,10 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// var peersUpdater *PeersUpdateManager
|
||||
@@ -86,470 +77,3 @@ func TestCloseChannel(t *testing.T) {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePeerMessageUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
existingUpdate *UpdateMessage
|
||||
newUpdate *UpdateMessage
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "update message with turn credentials update",
|
||||
peerID: "peer",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{},
|
||||
},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message for peer without existing update",
|
||||
peerID: "peer1",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with no changes in update",
|
||||
peerID: "peer2",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "update message with changes in checks",
|
||||
peerID: "peer3",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
Checks: []*proto.Checks{
|
||||
{
|
||||
Files: []string{"/usr/bin/netbird"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with lower serial number",
|
||||
peerID: "peer4",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewPeersUpdateManager(nil)
|
||||
ctx := context.Background()
|
||||
|
||||
if tt.existingUpdate != nil {
|
||||
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
|
||||
}
|
||||
|
||||
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNewPeerUpdateMessage(t *testing.T) {
|
||||
t.Run("Unchanged value", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating routes network", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Updating routes groups", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating network map peers", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.4"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer4-key",
|
||||
DNSLabel: "peer4",
|
||||
SSHKey: "peer4-ssh-key",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating process check", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
|
||||
newUpdateMessage3 := createMockUpdateMessage(t)
|
||||
newUpdateMessage3.Update.Checks = []*proto.Checks{}
|
||||
newUpdateMessage3.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage4 := createMockUpdateMessage(t)
|
||||
check := &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/local/netbird",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage4.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage5 := createMockUpdateMessage(t)
|
||||
check = &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/local/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage5.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating DNS configuration", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newDomain := "newexample.com"
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
|
||||
newDomain,
|
||||
)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating peer IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Add new firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newRule := &FirewallRule{
|
||||
PeerIP: "192.168.1.3",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Action: string(PolicyTrafficActionDrop),
|
||||
Protocol: string(PolicyRuleProtocolUDP),
|
||||
Port: "53",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Removing nameserver", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating name server IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating custom DNS zone", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
|
||||
t.Helper()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
|
||||
secretManager := NewTimeBasedAuthSecretsManager(
|
||||
NewPeersUpdateManager(nil),
|
||||
&TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{
|
||||
Duration: defaultDuration,
|
||||
},
|
||||
Secret: "secret",
|
||||
Turns: []*Host{TurnTestHost},
|
||||
},
|
||||
&Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "secret",
|
||||
},
|
||||
)
|
||||
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipNet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
{
|
||||
ID: "route2",
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route2",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
turnToken, err := secretManager.GenerateTurnToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
relayToken, err := secretManager.GenerateRelayToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &UpdateMessage{
|
||||
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
|
||||
NetworkMap: networkMap,
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -43,37 +43,34 @@ const (
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
||||
assert.Equal(t, newPAT.CreatedBy, mockUserID)
|
||||
|
||||
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
|
||||
pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting token ID by hashed token: %s", err)
|
||||
}
|
||||
|
||||
if tokenID == "" {
|
||||
if pat.ID == "" {
|
||||
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.ID, tokenID)
|
||||
assert.Equal(t, newPAT.ID, pat.ID)
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
|
||||
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||
}
|
||||
@@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockTargetUserId,
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockTargetUserId,
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: true,
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
|
||||
assert.Errorf(t, err, "Wrong expiration should thorw error")
|
||||
assert.Errorf(t, err, "Wrong expiration should throw error")
|
||||
}
|
||||
|
||||
func TestUser_DeletePAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID1,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken1,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
account, err = store.GetAccount(context.Background(), mockAccountID)
|
||||
account, err := store.GetAccount(context.Background(), mockAccountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
@@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
func TestUser_GetPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID1,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken1,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
func TestUser_GetAllPATs(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
mockTokenID2: {
|
||||
ID: mockTokenID2,
|
||||
HashedToken: mockToken2,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID1,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken1,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID2,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken2,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) {
|
||||
func TestUser_CreateServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when creating service user: %s", err)
|
||||
}
|
||||
|
||||
account, err = store.GetAccount(context.Background(), mockAccountID)
|
||||
account, err := store.GetAccount(context.Background(), mockAccountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(account.Users))
|
||||
@@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when creating user: %s", err)
|
||||
}
|
||||
|
||||
account, err = store.GetAccount(context.Background(), mockAccountID)
|
||||
account, err := store.GetAccount(context.Background(), mockAccountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, user.IsServiceUser)
|
||||
@@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
func TestUser_InviteNewUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -549,13 +519,13 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
tt.serviceUser.AccountID = mockAccountID
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser)
|
||||
assert.NoError(t, err, "failed to create service user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -582,12 +552,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -603,39 +570,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
|
||||
{
|
||||
Id: "user2",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
},
|
||||
{
|
||||
Id: "user3",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user4",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
},
|
||||
{
|
||||
Id: "user5",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err, "failed to save users")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -685,61 +651,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
}
|
||||
account.Users["user6"] = &User{
|
||||
Id: "user6",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user7"] = &User{
|
||||
Id: "user7",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user8"] = &User{
|
||||
Id: "user8",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
account.Users["user9"] = &User{
|
||||
Id: "user9",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
|
||||
{
|
||||
Id: "user2",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
},
|
||||
{
|
||||
Id: "user3",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user4",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
},
|
||||
{
|
||||
Id: "user5",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
},
|
||||
{
|
||||
Id: "user6",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user7",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user8",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
},
|
||||
{
|
||||
Id: "user9",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -786,7 +755,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
{
|
||||
name: "Delete non-existent user",
|
||||
userIDs: []string{"non-existent-user"},
|
||||
expectedReasons: []string{"target user: non-existent-user not found"},
|
||||
expectedReasons: []string{"user: non-existent-user not found"},
|
||||
expectedNotDeleted: []string{},
|
||||
},
|
||||
{
|
||||
@@ -816,7 +785,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
acc, err := am.Store.GetAccount(context.Background(), mockAccountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, id := range tc.expectedDeleted {
|
||||
@@ -836,12 +805,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -865,14 +831,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
newUser := NewRegularUser("normal_user1")
|
||||
newUser.AccountID = mockAccountID
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
newUser = NewRegularUser("normal_user2")
|
||||
newUser.AccountID = mockAccountID
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -946,15 +917,25 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
delete(account.Users, mockUserID)
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
||||
newUser.AccountID = mockAccountID
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID)
|
||||
assert.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
|
||||
err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings)
|
||||
assert.NoError(t, err, "failed to save account settings")
|
||||
|
||||
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID)
|
||||
assert.NoError(t, err, "failed to delete user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -968,7 +949,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
|
||||
assert.Equal(t, 1, len(users))
|
||||
|
||||
userInfo, _ := users[0].ToUserInfo(nil, account.Settings)
|
||||
userInfo, _ := users[0].ToUserInfo(nil, settings)
|
||||
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
|
||||
})
|
||||
}
|
||||
@@ -978,22 +959,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
externalUser := &User{
|
||||
Id: "externalUser",
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedIntegration,
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: "externalUser",
|
||||
AccountID: mockAccountID,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedIntegration,
|
||||
IntegrationReference: integration_reference.IntegrationReference{
|
||||
ID: 1,
|
||||
IntegrationType: "external",
|
||||
},
|
||||
}
|
||||
account.Users[externalUser.Id] = externalUser
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -1013,6 +993,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cacheManager := am.GetExternalCacheManager()
|
||||
|
||||
externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser")
|
||||
assert.NoError(t, err, "failed to get user")
|
||||
|
||||
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
||||
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"})
|
||||
assert.NoError(t, err)
|
||||
@@ -1042,17 +1026,17 @@ func TestUser_IsAdmin(t *testing.T) {
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockServiceUserID,
|
||||
AccountID: mockAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -1071,17 +1055,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockServiceUserID,
|
||||
AccountID: mockAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -1240,21 +1223,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
// create an account and an admin user
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create other users
|
||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||
account.Users[adminUserID] = NewAdminUser(adminUserID)
|
||||
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
regularUser := NewRegularUser(regularUserID)
|
||||
regularUser.AccountID = accountID
|
||||
|
||||
adminUser := NewAdminUser(adminUserID)
|
||||
adminUser.AccountID = accountID
|
||||
|
||||
serviceUser := &User{
|
||||
Id: serviceUserID,
|
||||
AccountID: accountID,
|
||||
IsServiceUser: true,
|
||||
Role: UserRoleAdmin,
|
||||
ServiceUserName: "service",
|
||||
}
|
||||
|
||||
updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update)
|
||||
err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser})
|
||||
assert.NoError(t, err, "failed to save users")
|
||||
|
||||
updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update)
|
||||
if tc.expectedErr {
|
||||
require.Errorf(t, err, "expecting SaveUser to throw an error")
|
||||
} else {
|
||||
@@ -1279,8 +1271,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
policy := &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1292,7 +1283,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
|
||||
@@ -3,7 +3,6 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
||||
conn, ok := c.conns[id]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
if conn.conn != connReference {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
@@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
if container.conn != connReference {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
keepUnusedServerTime = 5 * time.Second
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
@@ -27,10 +28,13 @@ type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
err error
|
||||
created time.Time
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{}
|
||||
return &RelayTrack{
|
||||
created: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
@@ -302,6 +306,18 @@ func (m *Manager) cleanUpUnusedRelays() {
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
// if the connection failed to the server the relay client will be nil
|
||||
// but the instance will be kept in the relayClients until the next locking
|
||||
if rt.err != nil {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(rt.created) <= keepUnusedServerTime {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
|
||||
@@ -288,8 +288,9 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||
time.Sleep(timeout)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
@@ -13,7 +12,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
PeerID: "test",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
|
||||
|
||||
func (c *Conn) ioErrHandling(err error) error {
|
||||
if c.isClosed() {
|
||||
return io.EOF
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
var wErr *websocket.CloseError
|
||||
@@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
|
||||
return err
|
||||
}
|
||||
if wErr.Code == websocket.StatusNormalClosure {
|
||||
return io.EOF
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -63,13 +63,14 @@ func (l *Listener) Shutdown(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
connRemoteAddr := remoteAddr(r)
|
||||
wsConn, err := websocket.Accept(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err)
|
||||
log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
|
||||
if err != nil {
|
||||
err = wsConn.Close(websocket.StatusInternalError, "internal error")
|
||||
if err != nil {
|
||||
@@ -90,3 +91,10 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
conn := NewConn(wsConn, lAddr, rAddr)
|
||||
l.acceptFn(conn)
|
||||
}
|
||||
|
||||
func remoteAddr(r *http.Request) string {
|
||||
if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -57,7 +57,7 @@ func (p *Peer) Work() {
|
||||
for {
|
||||
n, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
|
||||
@@ -88,18 +88,18 @@ type Route struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// Network and Domains are mutually exclusive
|
||||
Network netip.Prefix `gorm:"serializer:json"`
|
||||
Domains domain.List `gorm:"serializer:json"`
|
||||
KeepRoute bool
|
||||
NetID NetID
|
||||
Description string
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
NetworkType NetworkType
|
||||
Masquerade bool
|
||||
Metric int
|
||||
Enabled bool
|
||||
Groups []string `gorm:"serializer:json"`
|
||||
Network netip.Prefix `gorm:"serializer:json"`
|
||||
Domains domain.List `gorm:"serializer:json"`
|
||||
KeepRoute bool
|
||||
NetID NetID
|
||||
Description string
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
NetworkType NetworkType
|
||||
Masquerade bool
|
||||
Metric int
|
||||
Enabled bool
|
||||
Groups []string `gorm:"serializer:json"`
|
||||
AccessControlGroups []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
@@ -111,19 +111,20 @@ func (r *Route) EventMeta() map[string]any {
|
||||
// Copy copies a route object
|
||||
func (r *Route) Copy() *Route {
|
||||
route := &Route{
|
||||
ID: r.ID,
|
||||
Description: r.Description,
|
||||
NetID: r.NetID,
|
||||
Network: r.Network,
|
||||
Domains: slices.Clone(r.Domains),
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: r.NetworkType,
|
||||
Peer: r.Peer,
|
||||
PeerGroups: slices.Clone(r.PeerGroups),
|
||||
Metric: r.Metric,
|
||||
Masquerade: r.Masquerade,
|
||||
Enabled: r.Enabled,
|
||||
Groups: slices.Clone(r.Groups),
|
||||
ID: r.ID,
|
||||
AccountID: r.AccountID,
|
||||
Description: r.Description,
|
||||
NetID: r.NetID,
|
||||
Network: r.Network,
|
||||
Domains: slices.Clone(r.Domains),
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: r.NetworkType,
|
||||
Peer: r.Peer,
|
||||
PeerGroups: slices.Clone(r.PeerGroups),
|
||||
Metric: r.Metric,
|
||||
Masquerade: r.Masquerade,
|
||||
Enabled: r.Enabled,
|
||||
Groups: slices.Clone(r.Groups),
|
||||
AccessControlGroups: slices.Clone(r.AccessControlGroups),
|
||||
}
|
||||
return route
|
||||
@@ -149,7 +150,7 @@ func (r *Route) IsEqual(other *Route) bool {
|
||||
other.Masquerade == r.Masquerade &&
|
||||
other.Enabled == r.Enabled &&
|
||||
slices.Equal(r.Groups, other.Groups) &&
|
||||
slices.Equal(r.PeerGroups, other.PeerGroups)&&
|
||||
slices.Equal(r.PeerGroups, other.PeerGroups) &&
|
||||
slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,11 @@ import (
|
||||
|
||||
const (
|
||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||
NetbirdFwmark = 0x1BD00
|
||||
PreroutingFwmark = 0x1BD01
|
||||
NetbirdFwmark = 0x1BD00
|
||||
|
||||
PreroutingFwmarkRedirected = 0x1BD01
|
||||
PreroutingFwmarkMasquerade = 0x1BD11
|
||||
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
||||
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user