mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 07:34:55 -04:00
Compare commits
7 Commits
poc/netsta
...
debug-keyc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0cf6ece217 | ||
|
|
86340fb684 | ||
|
|
52e86cabc3 | ||
|
|
f73a2e2848 | ||
|
|
19fa071a93 | ||
|
|
cba3c549e9 | ||
|
|
65247de48d |
@@ -51,7 +51,7 @@ var loginCmd = &cobra.Command{
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
@@ -151,13 +151,21 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
)
|
||||
|
||||
@@ -94,7 +95,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
|
||||
@@ -66,13 +66,15 @@ type statusOutputOverview struct {
|
||||
}
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
prefixNamesFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -83,12 +85,14 @@ var statusCmd = &cobra.Command{
|
||||
|
||||
func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
@@ -172,8 +176,12 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||
}
|
||||
@@ -185,11 +193,26 @@ func parseFilters() error {
|
||||
return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err)
|
||||
}
|
||||
ipsFilterMap[addr] = struct{}{}
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for _, name := range prefixNamesFilter {
|
||||
prefixNamesFilterMap[strings.ToLower(name)] = struct{}{}
|
||||
}
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableDetailFlagWhenFilterFlag() {
|
||||
if !detailFlag && !jsonFlag && !yamlFlag {
|
||||
detailFlag = true
|
||||
}
|
||||
}
|
||||
|
||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
@@ -415,6 +438,7 @@ func parsePeers(peers peersStateOutput) string {
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := false
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@@ -431,5 +455,15 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
ipEval = true
|
||||
}
|
||||
}
|
||||
return statusEval || ipEval
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
}
|
||||
|
||||
@@ -85,7 +85,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
|
||||
@@ -215,12 +215,9 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
if *input.PreSharedKey != "" {
|
||||
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
log.Infof("new pre-shared key provided, replacing old key")
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
|
||||
@@ -6,8 +6,9 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
@@ -60,22 +61,7 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 4: new empty pre-shared key config -> fetch it
|
||||
newPreSharedKey := ""
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &newPreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 5: existing config, but new managementURL has been provided -> update config
|
||||
// case 4: existing config, but new managementURL has been provided -> update config
|
||||
newManagementURL := "https://test.newManagement.url:33071"
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: newManagementURL,
|
||||
|
||||
@@ -17,11 +17,12 @@ import (
|
||||
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/netbirdio/management-integrations/additions"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/additions"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -66,6 +67,7 @@ type AccountManager interface {
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
DeleteAccount(accountID, userID string) error
|
||||
MarkPATUsed(tokenID string) error
|
||||
@@ -1697,6 +1699,39 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
return am.dnsDomain
|
||||
}
|
||||
|
||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and set the list of groups with access permissions.
|
||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exists
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
@@ -1768,3 +1803,15 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
||||
}
|
||||
return acc
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
for _, allowedGroup := range allowedGroups {
|
||||
if userGroup == allowedGroup {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -220,6 +220,10 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
|
||||
return "", status.Errorf(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
@@ -312,7 +316,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||
if err != nil {
|
||||
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
|
||||
return nil, mapError(err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var sshKey []byte
|
||||
|
||||
@@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
||||
accountManager.GetAccountFromPAT,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.GetAccountFromToken,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
claimsExtractor,
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
|
||||
@@ -26,18 +26,18 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(token string) error
|
||||
|
||||
// GetAccountFromTokenFunc function
|
||||
type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
// CheckUserAccessByJWTGroupsFunc function
|
||||
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
getAccountFromToken GetAccountFromTokenFunc
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
audience string
|
||||
userIDClaim string
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
audience string
|
||||
userIDClaim string
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -46,20 +46,20 @@ const (
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
userIdClaim = jwtclaims.UserIDClaim
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
getAccountFromToken: getAccountFromToken,
|
||||
claimsExtractor: claimsExtractor,
|
||||
audience: audience,
|
||||
userIDClaim: userIdClaim,
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
claimsExtractor: claimsExtractor,
|
||||
audience: audience,
|
||||
userIDClaim: userIdClaim,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,34 +134,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
// group propagation and designated certain groups with access permissions.
|
||||
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
|
||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||
account, _, err := m.getAccountFromToken(authClaims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get the account from token: %w", err)
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.checkUserAccessByJWTGroups(authClaims)
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
@@ -217,15 +190,3 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
for _, allowedGroup := range allowedGroups {
|
||||
if userGroup == allowedGroup {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -73,17 +73,16 @@ func mockMarkPATUsed(token string) error {
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||
if testAccount.Id != claims.AccountId {
|
||||
return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
}
|
||||
|
||||
user, ok := testAccount.Users[claims.UserId]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId)
|
||||
if _, ok := testAccount.Users[claims.UserId]; !ok {
|
||||
return fmt.Errorf("user with id %s does not exist", claims.UserId)
|
||||
}
|
||||
|
||||
return testAccount, user, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
@@ -137,7 +136,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
mockGetAccountFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockGetAccountFromToken,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
claimsExtractor,
|
||||
audience,
|
||||
userIDClaim,
|
||||
|
||||
@@ -62,7 +62,7 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
@@ -354,13 +354,13 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//totalUsers, err := km.totalUsersCount()
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
q.Add("max", fmt.Sprint(200))
|
||||
|
||||
body, err := km.get("users", q)
|
||||
if err != nil {
|
||||
@@ -409,12 +409,19 @@ func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
log.Infof("Link header: %v", resp.Header.Get("Link"))
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// totalUsersCount returns the total count of all user created.
|
||||
// Used when fetching all registered accounts with pagination.
|
||||
func (km *KeycloakManager) totalUsersCount() (*int, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.Infof("Keycloak totalUsersCount took %d ms to handle", time.Since(start).Milliseconds())
|
||||
}()
|
||||
|
||||
body, err := km.get("users/count", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -69,6 +69,7 @@ type MockAccountManager struct {
|
||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any)
|
||||
@@ -543,6 +544,13 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||
if am.CheckUserAccessByJWTGroupsFunc != nil {
|
||||
return am.CheckUserAccessByJWTGroupsFunc(claims)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
|
||||
}
|
||||
|
||||
// GetPeers mocks GetPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.GetPeersFunc != nil {
|
||||
|
||||
@@ -493,7 +493,11 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([]
|
||||
|
||||
for _, p := range group.Peers {
|
||||
peer, ok := account.Peers[p]
|
||||
if ok && peer != nil && peer.ID == peerID {
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user