Compare commits

...

1 Commits

Author SHA1 Message Date
crn4
23d4de3957 [management] check policy for changes before actual db update 2026-02-20 15:20:59 +01:00
5 changed files with 680 additions and 26 deletions

View File

@@ -46,25 +46,46 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
var isUpdate = policy.ID != ""
var updateAccountPeers bool
var action = activity.PolicyAdded
var unchanged bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
if err != nil {
return err
}
saveFunc := transaction.CreatePolicy
if isUpdate {
action = activity.PolicyUpdated
saveFunc = transaction.SavePolicy
}
existingPolicy, getErr := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if getErr != nil {
return getErr
}
if err = saveFunc(ctx, policy); err != nil {
return err
existingPolicy.Normalize()
policy.Normalize()
if policy.Equal(existingPolicy) {
unchanged = true
return nil
}
action = activity.PolicyUpdated
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
if err != nil {
return err
}
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
} else {
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
if err != nil {
return err
}
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}
}
return transaction.IncrementNetworkSerial(ctx, accountID)
@@ -73,6 +94,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return nil, err
}
if unchanged {
return policy, nil
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
@@ -146,26 +171,38 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
return false, err
}
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
return arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil

View File

@@ -3,6 +3,7 @@ package types
import (
"errors"
"fmt"
"slices"
"strconv"
"strings"
)
@@ -93,6 +94,54 @@ func (p *Policy) Copy() *Policy {
return c
}
func (p *Policy) Equal(other *Policy) bool {
if p == nil || other == nil {
return p == other
}
if p.ID != other.ID ||
p.AccountID != other.AccountID ||
p.Name != other.Name ||
p.Description != other.Description ||
p.Enabled != other.Enabled {
return false
}
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
return false
}
if len(p.Rules) != len(other.Rules) {
return false
}
otherRules := make(map[string]*PolicyRule, len(other.Rules))
for _, r := range other.Rules {
otherRules[r.ID] = r
}
for _, r := range p.Rules {
otherRule, ok := otherRules[r.ID]
if !ok {
return false
}
if !r.Equal(otherRule) {
return false
}
}
return true
}
func (p *Policy) Normalize() {
if p == nil {
return
}
slices.Sort(p.SourcePostureChecks)
for _, r := range p.Rules {
r.Normalize()
}
}
// EventMeta returns activity event meta related to this policy
func (p *Policy) EventMeta() map[string]any {
return map[string]any{"name": p.Name}

View File

@@ -0,0 +1,218 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
a := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "test",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
},
}
b := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "test",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
},
}
assert.True(t, a.Equal(b))
}
func TestPolicyEqual_DifferentRules(t *testing.T) {
a := &Policy{
ID: "pol1",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
},
}
b := &Policy{
ID: "pol1",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
a := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1"},
},
}
b := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1"},
{ID: "r2", PolicyID: "pol1"},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
a := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
}
b := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
}
assert.True(t, a.Equal(b))
}
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
a := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc1", "pc2"},
}
b := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc1", "pc3"},
}
assert.False(t, a.Equal(b))
}
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
base := Policy{
ID: "pol1",
AccountID: "acc1",
Name: "test",
Description: "desc",
Enabled: true,
}
other := base
other.Name = "changed"
assert.False(t, base.Equal(&other))
other = base
other.Enabled = false
assert.False(t, base.Equal(&other))
other = base
other.Description = "changed"
assert.False(t, base.Equal(&other))
}
func TestPolicyEqual_NilCases(t *testing.T) {
var a *Policy
var b *Policy
assert.True(t, a.Equal(b))
a = &Policy{ID: "pol1"}
assert.False(t, a.Equal(nil))
}
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
a := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1"},
},
}
b := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r2", PolicyID: "pol1"},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyNormalize(t *testing.T) {
p := &Policy{
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
Rules: []*PolicyRule{
{
ID: "r1",
Sources: []string{"g2", "g1"},
Destinations: []string{"g4", "g3"},
Ports: []string{"443", "80"},
},
},
}
p.Normalize()
assert.Equal(t, []string{"pc1", "pc2", "pc3"}, p.SourcePostureChecks)
assert.Equal(t, []string{"g1", "g2"}, p.Rules[0].Sources)
assert.Equal(t, []string{"g3", "g4"}, p.Rules[0].Destinations)
assert.Equal(t, []string{"443", "80"}, p.Rules[0].Ports)
}
func TestPolicyNormalize_Nil(t *testing.T) {
var p *Policy
p.Normalize()
}
func TestPolicyEqual_FullScenario(t *testing.T) {
a := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "Web Access",
Description: "Allow web access",
Enabled: true,
SourcePostureChecks: []string{"pc2", "pc1"},
Rules: []*PolicyRule{
{
ID: "r1",
PolicyID: "pol1",
Name: "HTTP",
Enabled: true,
Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP,
Bidirectional: true,
Sources: []string{"g2", "g1"},
Destinations: []string{"g4", "g3"},
Ports: []string{"443", "80", "8080"},
PortRanges: []RulePortRange{
{Start: 8000, End: 9000},
{Start: 80, End: 80},
},
},
},
}
b := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "Web Access",
Description: "Allow web access",
Enabled: true,
SourcePostureChecks: []string{"pc1", "pc2"},
Rules: []*PolicyRule{
{
ID: "r1",
PolicyID: "pol1",
Name: "HTTP",
Enabled: true,
Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP,
Bidirectional: true,
Sources: []string{"g1", "g2"},
Destinations: []string{"g3", "g4"},
Ports: []string{"80", "8080", "443"},
PortRanges: []RulePortRange{
{Start: 80, End: 80},
{Start: 8000, End: 9000},
},
},
},
}
assert.True(t, a.Equal(b))
}

View File

@@ -1,6 +1,9 @@
package types
import (
"slices"
"sort"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -118,3 +121,125 @@ func (pm *PolicyRule) Copy() *PolicyRule {
}
return rule
}
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
if pm == nil || other == nil {
return pm == other
}
if pm.ID != other.ID ||
pm.PolicyID != other.PolicyID ||
pm.Name != other.Name ||
pm.Description != other.Description ||
pm.Enabled != other.Enabled ||
pm.Action != other.Action ||
pm.Bidirectional != other.Bidirectional ||
pm.Protocol != other.Protocol ||
pm.SourceResource != other.SourceResource ||
pm.DestinationResource != other.DestinationResource ||
pm.AuthorizedUser != other.AuthorizedUser {
return false
}
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
return false
}
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
return false
}
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
return false
}
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
return false
}
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
return false
}
return true
}
func (pm *PolicyRule) Normalize() {
if pm == nil {
return
}
slices.Sort(pm.Sources)
slices.Sort(pm.Destinations)
slices.Sort(pm.Ports)
sort.Slice(pm.PortRanges, func(i, j int) bool {
if pm.PortRanges[i].Start != pm.PortRanges[j].Start {
return pm.PortRanges[i].Start < pm.PortRanges[j].Start
}
return pm.PortRanges[i].End < pm.PortRanges[j].End
})
for k, v := range pm.AuthorizedGroups {
slices.Sort(v)
pm.AuthorizedGroups[k] = v
}
}
func stringSlicesEqualUnordered(a, b []string) bool {
if len(a) != len(b) {
return false
}
if len(a) == 0 {
return true
}
sorted1 := make([]string, len(a))
sorted2 := make([]string, len(b))
copy(sorted1, a)
copy(sorted2, b)
slices.Sort(sorted1)
slices.Sort(sorted2)
return slices.Equal(sorted1, sorted2)
}
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
if len(a) != len(b) {
return false
}
if len(a) == 0 {
return true
}
cmp := func(x, y RulePortRange) int {
if x.Start != y.Start {
if x.Start < y.Start {
return -1
}
return 1
}
if x.End != y.End {
if x.End < y.End {
return -1
}
return 1
}
return 0
}
sorted1 := make([]RulePortRange, len(a))
sorted2 := make([]RulePortRange, len(b))
copy(sorted1, a)
copy(sorted2, b)
slices.SortFunc(sorted1, cmp)
slices.SortFunc(sorted2, cmp)
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
return x.Start == y.Start && x.End == y.End
})
}
func authorizedGroupsEqual(a, b map[string][]string) bool {
if len(a) != len(b) {
return false
}
for k, va := range a {
vb, ok := b[k]
if !ok {
return false
}
if !stringSlicesEqualUnordered(va, vb) {
return false
}
}
return true
}

View File

@@ -0,0 +1,225 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"443", "80", "22"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"22", "443", "80"},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"443", "80"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"443", "22"},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g1", "g2", "g3"},
Destinations: []string{"g4", "g5"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g3", "g1", "g2"},
Destinations: []string{"g5", "g4"},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g1", "g2"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g1", "g3"},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 8000, End: 9000},
{Start: 80, End: 80},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 80, End: 80},
{Start: 8000, End: 9000},
},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 80, End: 80},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 80, End: 443},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g1": {"u1", "u2", "u3"},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g1": {"u3", "u1", "u2"},
},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g1": {"u1"},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g2": {"u1"},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
base := PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Name: "test",
Description: "desc",
Enabled: true,
Action: PolicyTrafficActionAccept,
Bidirectional: true,
Protocol: PolicyRuleProtocolTCP,
}
other := base
other.Name = "changed"
assert.False(t, base.Equal(&other))
other = base
other.Enabled = false
assert.False(t, base.Equal(&other))
other = base
other.Action = PolicyTrafficActionDrop
assert.False(t, base.Equal(&other))
other = base
other.Protocol = PolicyRuleProtocolUDP
assert.False(t, base.Equal(&other))
}
func TestPolicyRuleEqual_NilCases(t *testing.T) {
var a *PolicyRule
var b *PolicyRule
assert.True(t, a.Equal(b))
a = &PolicyRule{ID: "rule1"}
assert.False(t, a.Equal(nil))
}
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{},
Sources: nil,
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: nil,
Sources: []string{},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleNormalize(t *testing.T) {
rule := &PolicyRule{
Sources: []string{"g3", "g1", "g2"},
Destinations: []string{"g6", "g4", "g5"},
Ports: []string{"443", "80", "22"},
PortRanges: []RulePortRange{
{Start: 8000, End: 9000},
{Start: 80, End: 80},
{Start: 80, End: 443},
},
AuthorizedGroups: map[string][]string{
"g1": {"u3", "u1", "u2"},
},
}
rule.Normalize()
assert.Equal(t, []string{"g1", "g2", "g3"}, rule.Sources)
assert.Equal(t, []string{"g4", "g5", "g6"}, rule.Destinations)
assert.Equal(t, []string{"22", "443", "80"}, rule.Ports)
assert.Equal(t, []RulePortRange{
{Start: 80, End: 80},
{Start: 80, End: 443},
{Start: 8000, End: 9000},
}, rule.PortRanges)
assert.Equal(t, []string{"u1", "u2", "u3"}, rule.AuthorizedGroups["g1"])
}
func TestPolicyRuleNormalize_Nil(t *testing.T) {
var rule *PolicyRule
rule.Normalize()
}