Add context to throughout the project and update logging (#2209)

propagate context from all the API calls and log request ID, account ID and peer ID

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
pascal-fischer
2024-07-03 11:33:02 +02:00
committed by GitHub
parent 7cb81f1d70
commit 765aba2c1c
127 changed files with 2936 additions and 2642 deletions

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
_ "embed"
"strconv"
"strings"
@@ -211,9 +212,9 @@ type FirewallRule struct {
// getPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator()
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
@@ -224,8 +225,8 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma
continue
}
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap)
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
@@ -254,7 +255,7 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
@@ -262,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
all, err := a.GetGroupAll()
if err != nil {
log.Errorf("failed to get group all: %v", err)
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
all = &nbgroup.Group{}
}
@@ -313,11 +314,11 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
}
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -341,11 +342,11 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
}
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@@ -353,7 +354,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
exists := am.savePolicy(account, policy)
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
@@ -361,19 +362,19 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
if exists {
action = activity.PolicyUpdated
}
am.StoreEvent(userID, policy.ID, accountID, action, policy.EventMeta())
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@@ -384,23 +385,23 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@@ -490,7 +491,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
@@ -506,7 +507,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou
}
// validate the peer based on policy posture checks applied
isValid := account.validatePostureChecksOnPeer(sourcePostureChecksIDs, peer.ID)
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
@@ -527,7 +528,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou
}
// validatePostureChecksOnPeer validates the posture checks on a peer
func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) bool {
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
peer, ok := a.Peers[peerID]
if !ok && peer == nil {
return false
@@ -540,9 +541,9 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe
}
for _, check := range postureChecks.GetChecks() {
isValid, err := check.Check(*peer)
isValid, err := check.Check(ctx, *peer)
if err != nil {
log.Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
}
if !isValid {
return false