Compare commits

...

4 Commits

Author SHA1 Message Date
Viktor Liu
063cbdc6d8 Enable trace logging in WASM client 2026-01-16 15:06:58 +08:00
Maycon Santos
291e640b28 [client] Change priority between local and dns route handlers (#5106)
* Change priority between local and dns route handlers

* update priority tests
2026-01-15 17:30:10 +01:00
Pascal Fischer
efb954b7d6 [management] adapt ratelimiting (#5080) 2026-01-15 16:39:14 +01:00
Vlad
cac9326d3d [management] fetch all users data from external cache in one request (#5104)
---------

Co-authored-by: pascal <pascal@netbird.io>
2026-01-14 17:09:17 +01:00
10 changed files with 368 additions and 51 deletions

View File

@@ -10,6 +10,7 @@ import (
"net/netip"
"os"
"sync"
"time"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
@@ -30,6 +31,11 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
const (
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
)
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -258,18 +264,40 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
// With lazy connections, the connection is established on first traffic.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)
// Check context status upfront
if ctx.Err() != nil {
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
return nil, ctx.Err()
}
engine, err := c.getEngine()
if err != nil {
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
return nil, err
}
nsnet, err := engine.GetNet()
if err != nil {
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when DialContext is called. The netstack
// dial triggers WireGuard traffic which activates the lazy connection.
logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
conn, err := nsnet.DialContext(ctx, network, address)
if err != nil {
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
return nil, err
}
logrus.Infof("embed.Dial: successfully connected to %s", address)
return conn, nil
}
// DialContext dials a network address in the netbird network with context
@@ -368,6 +396,35 @@ func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
return syncResp, nil
}
// WaitForPeerConnection waits for a peer with the given IP to be connected.
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
logrus.Infof("Waiting for peer %s to be connected", peerIP)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
case <-ticker.C:
status, err := c.Status()
if err != nil {
logrus.Debugf("Error getting status while waiting for peer: %v", err)
continue
}
for _, p := range status.Peers {
if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
return nil
}
}
logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
@@ -381,9 +438,8 @@ func (c *Client) SetLogLevel(levelStr string) error {
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
// Note: ConnectClient doesn't have SetLogLevel method
_ = connect
return nil
}

View File

@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
}
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugf("dialing %s %s", network, addr)
log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
conn, err := d.net.Dial(network, addr)
if err != nil {
log.Debugf("failed to deal connection: %s", err)
log.Warnf("NSDialer.Dial failed: %s", err)
}
return conn, err
}

View File

@@ -16,8 +16,8 @@ import (
const (
PriorityMgmtCache = 150
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityDNSRoute = 100
PriorityLocal = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100

View File

@@ -185,7 +185,7 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -226,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -256,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -278,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -304,7 +304,7 @@ func TestUpdateDNSServer(t *testing.T) {
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -320,7 +320,7 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -2052,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -23,11 +23,13 @@ import (
)
const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
icmpEchoRequest = 8
icmpCodeEcho = 0
@@ -169,6 +171,9 @@ func createSSHMethod(client *netbird.Client) js.Func {
}
return createPromise(func(resolve, reject js.Value) {
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called in ssh.Connect().
sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
@@ -399,20 +404,79 @@ func createStatusSummaryMethod(client *netbird.Client) js.Func {
// createStatusDetailMethod creates the statusDetail method
func createStatusDetailMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Info("statusDetail called")
return createPromise(func(resolve, reject js.Value) {
log.Info("statusDetail: getting overview")
overview, err := getStatusOverview(client)
if err != nil {
log.Errorf("statusDetail: getStatusOverview failed: %v", err)
reject.Invoke(js.ValueOf(err.Error()))
return
}
log.Info("statusDetail: generating full detail summary")
detail := overview.FullDetailSummary()
log.Infof("statusDetail: detail length=%d", len(detail))
js.Global().Get("console").Call("log", detail)
resolve.Invoke(js.Undefined())
})
})
}
// createWaitForPeerConnectionMethod creates a method that waits for a peer to be connected
func createWaitForPeerConnectionMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
if len(args) < 1 {
reject.Invoke(js.ValueOf("peer IP address required"))
return
}
peerIP := args[0].String()
timeoutMs := int(defaultPeerConnectionTimeout.Milliseconds())
if len(args) > 1 && !args[1].IsUndefined() && !args[1].IsNull() {
timeoutMs = args[1].Int()
}
timeout := time.Duration(timeoutMs) * time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
log.Infof("Waiting for peer %s to be connected (timeout: %v)", peerIP, timeout)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
reject.Invoke(js.ValueOf(fmt.Sprintf("timeout waiting for peer %s to connect", peerIP)))
return
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status: %v", err)
continue
}
for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s)", peerIP, peer.ConnType)
resolve.Invoke(js.ValueOf(map[string]interface{}{
"ip": peer.IP,
"status": peer.Status,
"connType": peer.ConnType,
}))
return
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
})
})
}
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
@@ -471,12 +535,51 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value {
resolve := promiseArgs[0]
reject := promiseArgs[1]
go handler(resolve, reject)
// Wrap reject to always log the error
loggingReject := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) > 0 {
log.Errorf("Promise rejected: %v", args[0])
js.Global().Get("console").Call("error", "WASM Promise rejected:", args[0])
}
reject.Invoke(args[0])
return nil
})
go handler(resolve, loggingReject.Value)
return nil
}))
}
// waitForPeerConnected waits for a peer with the given IP to be connected
func waitForPeerConnected(ctx context.Context, client *netbird.Client, peerIP string) error {
log.Infof("Waiting for peer %s to be connected before operation", peerIP)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect", peerIP)
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status while waiting for peer: %v", err)
continue
}
for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s), proceeding with operation", peerIP, peer.ConnType)
return nil
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}
// createDetectSSHServerMethod creates the SSH server detection method
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -499,6 +602,10 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called. Waiting would cause
// a deadlock since lazy connections only open on traffic.
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
if err != nil {
reject.Invoke(err.Error())
@@ -525,6 +632,7 @@ func createClientObject(client *netbird.Client) js.Value {
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
obj["waitForPeerConnection"] = createWaitForPeerConnectionMethod(client)
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
obj["setLogLevel"] = createSetLogLevelMethod(client)
@@ -549,7 +657,10 @@ func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return
}
if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil {
// Force trace logging for debugging - must be set BEFORE netbird.New
// as New() will call logrus.SetLevel with options.LogLevel
options.LogLevel = "trace"
if err := util.InitLog("trace", util.LogConsole); err != nil {
log.Warnf("Failed to initialize logging: %v", err)
}

View File

@@ -26,6 +26,8 @@ type UserDataCache interface {
Get(ctx context.Context, key string) (*idp.UserData, error)
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
Delete(ctx context.Context, key string) error
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
}
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
return u.cache.Delete(ctx, key)
}
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
var users []*idp.UserData
v, err := u.cache.Get(ctx, key, &users)
if err != nil {
return nil, err
}
switch v := v.(type) {
case []*idp.UserData:
return v, nil
case *[]*idp.UserData:
return *v, nil
case []byte:
return unmarshalUserData(v)
}
return nil, fmt.Errorf("unexpected type: %T", v)
}
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
}
// NewUserDataCache creates a new UserDataCacheImpl object.
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
simpleCache := cache.New[any](store)

View File

@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil {
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
func isTerraformRequest(r *http.Request) bool {
ua := strings.ToLower(r.Header.Get("User-Agent"))
return strings.Contains(ua, "terraform")
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {

View File

@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test various Terraform user agent formats
terraformUserAgents := []string{
"Terraform/1.5.0",
"terraform/1.0.0",
"Terraform-Provider/2.0.0",
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
}
for _, userAgent := range terraformUserAgents {
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", userAgent)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
})
}
})
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {

View File

@@ -911,10 +911,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{}
switch {
case allowed:
start := time.Now()
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
log.WithContext(ctx).Tracef("Got %d users from account %s after %s", len(accountUsers), accountID, time.Since(start))
case user != nil && user.AccountID == accountID:
accountUsers = append(accountUsers, user)
default:
@@ -933,23 +935,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0)
filtered := make(map[string]*idp.UserData, len(accountUsers))
log.WithContext(ctx).Tracef("Querying users from IDP for account %s", accountID)
start := time.Now()
integrationKeys := make(map[string]struct{})
for _, user := range accountUsers {
if user.Issued == types.UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = true
continue
}
usersFromIntegration = append(usersFromIntegration, info)
integrationKeys[user.IntegrationReference.CacheKey(accountID)] = struct{}{}
continue
}
if !user.IsServiceUser {
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
}
}
for key := range integrationKeys {
usersData, err := am.externalCacheManager.GetUsers(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("GetUsers from ExternalCache for key: %s, error: %s", key, err)
continue
}
for _, ud := range usersData {
filtered[ud.ID] = ud
}
}
for _, ud := range filtered {
usersFromIntegration = append(usersFromIntegration, ud)
}
log.WithContext(ctx).Tracef("Got user info from external cache after %s", time.Since(start))
start = time.Now()
queriedUsers, err = am.lookupCache(ctx, users, accountID)
log.WithContext(ctx).Tracef("Got user info from cache for %d users after %s", len(queriedUsers), time.Since(start))
if err != nil {
return nil, err
}

View File

@@ -1086,8 +1086,12 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err)
cacheManager := am.GetExternalCacheManager()
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}, time.Minute)
tud := &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}
cacheKeyUser := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKeyUser, tud, time.Minute)
assert.NoError(t, err)
cacheKeyAccount := externalUser.IntegrationReference.CacheKey(mockAccountID)
err = cacheManager.SetUsers(context.Background(), cacheKeyAccount, []*idp.UserData{tud}, time.Minute)
assert.NoError(t, err)
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)