mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
[client, management] Feature/ssh fine grained access (#4969)
Add fine-grained SSH access control with authorized users/groups
This commit is contained in:
@@ -1151,6 +1151,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
|
||||
@@ -11,15 +11,18 @@ import (
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
type sshServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
GetStatus() (bool, []sshserver.SessionInfo)
|
||||
UpdateSSHAuth(config *sshauth.Config)
|
||||
}
|
||||
|
||||
func (e *Engine) setupSSHPortRedirection() error {
|
||||
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
|
||||
|
||||
return sshServer.GetStatus()
|
||||
}
|
||||
|
||||
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
|
||||
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
|
||||
if sshAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if e.sshServer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := sshAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range sshAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
// Update SSH server with new authorization configuration
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshAuth.GetUserIDClaim(),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
}
|
||||
|
||||
e.sshServer.UpdateSSHAuth(authConfig)
|
||||
}
|
||||
|
||||
184
client/ssh/auth/auth.go
Normal file
184
client/ssh/auth/auth.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
|
||||
DefaultUserIDClaim = "sub"
|
||||
// Wildcard is a special user ID that matches all users
|
||||
Wildcard = "*"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
)
|
||||
|
||||
// Authorizer handles SSH fine-grained access control authorization
|
||||
type Authorizer struct {
|
||||
// UserIDClaim is the JWT claim to extract the user ID from
|
||||
userIDClaim string
|
||||
|
||||
// authorizedUsers is a list of hashed user IDs authorized to access this peer
|
||||
authorizedUsers []sshuserhash.UserIDHash
|
||||
|
||||
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||
machineUsers map[string][]uint32
|
||||
|
||||
// mu protects the list of users
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config contains configuration for the SSH authorizer
|
||||
type Config struct {
|
||||
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
|
||||
UserIDClaim string
|
||||
|
||||
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
|
||||
AuthorizedUsers []sshuserhash.UserIDHash
|
||||
|
||||
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||
MachineUsers map[string][]uint32
|
||||
}
|
||||
|
||||
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||
func NewAuthorizer() *Authorizer {
|
||||
a := &Authorizer{
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// Update updates the authorizer configuration with new values
|
||||
func (a *Authorizer) Update(config *Config) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if config == nil {
|
||||
// Clear authorization
|
||||
a.userIDClaim = DefaultUserIDClaim
|
||||
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||
a.machineUsers = make(map[string][]uint32)
|
||||
log.Info("SSH authorization cleared")
|
||||
return
|
||||
}
|
||||
|
||||
userIDClaim := config.UserIDClaim
|
||||
if userIDClaim == "" {
|
||||
userIDClaim = DefaultUserIDClaim
|
||||
}
|
||||
a.userIDClaim = userIDClaim
|
||||
|
||||
// Store authorized users list
|
||||
a.authorizedUsers = config.AuthorizedUsers
|
||||
|
||||
// Store machine users mapping
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range config.MachineUsers {
|
||||
if len(indexes) > 0 {
|
||||
machineUsers[osUser] = indexes
|
||||
}
|
||||
}
|
||||
a.machineUsers = machineUsers
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||
len(config.AuthorizedUsers), len(machineUsers))
|
||||
}
|
||||
|
||||
// Authorize validates if a user is authorized to login as the specified OS user
|
||||
// Returns nil if authorized, or an error describing why authorization failed
|
||||
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
|
||||
if jwtUserID == "" {
|
||||
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
|
||||
return ErrEmptyUserID
|
||||
}
|
||||
|
||||
// Hash the JWT user ID for comparison
|
||||
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
|
||||
if err != nil {
|
||||
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
|
||||
return fmt.Errorf("failed to hash user ID: %w", err)
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
// Find the index of this user in the authorized list
|
||||
userIndex, found := a.findUserIndex(hashedUserID)
|
||||
if !found {
|
||||
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
|
||||
return ErrUserNotAuthorized
|
||||
}
|
||||
|
||||
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
|
||||
}
|
||||
|
||||
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
|
||||
// Checks wildcard mapping first, then specific OS user mappings
|
||||
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
|
||||
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
|
||||
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
|
||||
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
|
||||
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific OS username mapping
|
||||
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
|
||||
if !hasMachineUserMapping {
|
||||
// No mapping for this OS user - deny by default (fail closed)
|
||||
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
|
||||
return ErrNoMachineUserMapping
|
||||
}
|
||||
|
||||
// Check if user's index is in the allowed indexes for this specific OS user
|
||||
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
|
||||
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
|
||||
return ErrUserNotMappedToOSUser
|
||||
}
|
||||
|
||||
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserIDClaim returns the JWT claim name used to extract user IDs
|
||||
func (a *Authorizer) GetUserIDClaim() string {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return a.userIDClaim
|
||||
}
|
||||
|
||||
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||
// Returns the index and true if found, 0 and false if not found
|
||||
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||
for i, id := range a.authorizedUsers {
|
||||
if id == hashedUserID {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// isIndexInList checks if an index exists in a list of indexes
|
||||
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
|
||||
for _, idx := range indexes {
|
||||
if idx == index {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
612
client/ssh/auth/auth_test.go
Normal file
612
client/ssh/auth/auth_test.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list with one user
|
||||
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
|
||||
MachineUsers: map[string][]uint32{},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Try to authorize a different user
|
||||
err = authorizer.Authorize("unauthorized-user", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// All attempts should fail when no machine user mappings exist (fail closed)
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
user3Hash, err := sshauth.HashUserID("user3")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 1}, // user1 and user2 can access root
|
||||
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||
"admin": {0}, // only user1 can access admin
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 (index 0) should access root and admin
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user1", "admin")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user2 (index 1) should access root and postgres
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user3 (index 2) should access postgres
|
||||
err = authorizer.Authorize("user3", "postgres")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
user3Hash, err := sshauth.HashUserID("user3")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 1}, // user1 and user2 can access root
|
||||
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||
"admin": {0}, // only user1 can access admin
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 (index 0) should NOT access postgres
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user2 (index 1) should NOT access admin
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user3 (index 2) should NOT access root
|
||||
err = authorizer.Authorize("user3", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user3 (index 2) should NOT access admin
|
||||
err = authorizer.Authorize("user3", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0}, // only root is mapped
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 should NOT access an unmapped OS user (fail closed)
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users list
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Empty user ID should fail
|
||||
err = authorizer.Authorize("", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrEmptyUserID)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up multiple authorized users
|
||||
userHashes := make([]sshauth.UserIDHash, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
|
||||
require.NoError(t, err)
|
||||
userHashes[i] = hash
|
||||
}
|
||||
|
||||
// Create machine user mapping for all users
|
||||
rootIndexes := make([]uint32, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
rootIndexes[i] = uint32(i)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: userHashes,
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": rootIndexes,
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// All users should be authorized for root
|
||||
for i := 0; i < 10; i++ {
|
||||
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
|
||||
assert.NoError(t, err, "user%d should be authorized", i)
|
||||
}
|
||||
|
||||
// User not in list should fail
|
||||
err := authorizer.Authorize("unknown-user", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up initial configuration
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{"root": {0}},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 should be authorized
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Clear configuration
|
||||
authorizer.Update(nil)
|
||||
|
||||
// user1 should no longer be authorized
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Machine users with empty index lists should be filtered out
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0},
|
||||
"postgres": {}, // empty list - should be filtered out
|
||||
"admin": nil, // nil list - should be filtered out
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// root should work
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// postgres should fail (no mapping)
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
// admin should fail (no mapping)
|
||||
err = authorizer.Authorize("user1", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up with custom user ID claim
|
||||
user1Hash, err := sshauth.HashUserID("user@example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: "email",
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0},
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Verify the custom claim is set
|
||||
assert.Equal(t, "email", authorizer.GetUserIDClaim())
|
||||
|
||||
// Authorize with email as user ID
|
||||
err = authorizer.Authorize("user@example.com", "root")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Verify default claim
|
||||
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
|
||||
|
||||
// Set up with empty user ID claim (should use default)
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: "", // empty - should use default
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Should fall back to default
|
||||
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||
}
|
||||
|
||||
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Create a large authorized users list
|
||||
const numUsers = 1000
|
||||
userHashes := make([]sshauth.UserIDHash, numUsers)
|
||||
for i := 0; i < numUsers; i++ {
|
||||
hash, err := sshauth.HashUserID("user" + string(rune(i)))
|
||||
require.NoError(t, err)
|
||||
userHashes[i] = hash
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: userHashes,
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 500, 999}, // first, middle, and last user
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// First user should have access
|
||||
err := authorizer.Authorize("user"+string(rune(0)), "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Middle user should have access
|
||||
err = authorizer.Authorize("user"+string(rune(500)), "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Last user should have access
|
||||
err = authorizer.Authorize("user"+string(rune(999)), "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// User not in mapping should NOT have access
|
||||
err = authorizer.Authorize("user"+string(rune(100)), "root")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Set up authorized users
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0, 1},
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Test concurrent authorization calls (should be safe to read concurrently)
|
||||
const numGoroutines = 100
|
||||
errChan := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(idx int) {
|
||||
user := "user1"
|
||||
if idx%2 == 0 {
|
||||
user = "user2"
|
||||
}
|
||||
err := authorizer.Authorize(user, "root")
|
||||
errChan <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete and collect errors
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errChan
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
user3Hash, err := sshauth.HashUserID("user3")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with wildcard - all authorized users can access any OS user
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0, 1, 2}, // wildcard with all user indexes
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// All authorized users should be able to access any OS user
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user3", "admin")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user1", "ubuntu")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "nginx")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user3", "docker")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with wildcard
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0},
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 should have access
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Unauthorized user should still be denied even with wildcard
|
||||
err = authorizer.Authorize("unauthorized-user", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with both wildcard and specific mappings
|
||||
// Wildcard takes precedence for users in the wildcard index list
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0, 1}, // wildcard for both users
|
||||
"root": {0}, // specific mapping that would normally restrict to user1 only
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Both users should be able to access any other OS user via wildcard
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
user1Hash, err := sshauth.HashUserID("user1")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure WITHOUT wildcard - only specific mappings
|
||||
config := &Config{
|
||||
UserIDClaim: DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"root": {0}, // only user1
|
||||
"postgres": {1}, // only user2
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// user1 can access root
|
||||
err = authorizer.Authorize("user1", "root")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user2 can access postgres
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// user1 cannot access postgres
|
||||
err = authorizer.Authorize("user1", "postgres")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// user2 cannot access root
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||
|
||||
// Neither can access unmapped OS users
|
||||
err = authorizer.Authorize("user1", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user2", "admin")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
}
|
||||
|
||||
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||
// This test covers the scenario where wildcard exists with limited indexes.
|
||||
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
|
||||
// Other users can only access OS users they are explicitly mapped to.
|
||||
authorizer := NewAuthorizer()
|
||||
|
||||
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
|
||||
wasmHash, err := sshauth.HashUserID("wasm")
|
||||
require.NoError(t, err)
|
||||
user2Hash, err := sshauth.HashUserID("user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Configure with wildcard having only index 0, and specific mappings for other OS users
|
||||
config := &Config{
|
||||
UserIDClaim: "sub",
|
||||
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
|
||||
"alice": {1}, // specific mapping for user2
|
||||
"bob": {1}, // specific mapping for user2
|
||||
},
|
||||
}
|
||||
authorizer.Update(config)
|
||||
|
||||
// wasm (index 0) should access any OS user via wildcard
|
||||
err = authorizer.Authorize("wasm", "root")
|
||||
assert.NoError(t, err, "wasm should access root via wildcard")
|
||||
|
||||
err = authorizer.Authorize("wasm", "alice")
|
||||
assert.NoError(t, err, "wasm should access alice via wildcard")
|
||||
|
||||
err = authorizer.Authorize("wasm", "bob")
|
||||
assert.NoError(t, err, "wasm should access bob via wildcard")
|
||||
|
||||
err = authorizer.Authorize("wasm", "postgres")
|
||||
assert.NoError(t, err, "wasm should access postgres via wildcard")
|
||||
|
||||
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
|
||||
err = authorizer.Authorize("user2", "alice")
|
||||
assert.NoError(t, err, "user2 should access alice via explicit mapping")
|
||||
|
||||
err = authorizer.Authorize("user2", "bob")
|
||||
assert.NoError(t, err, "user2 should access bob via explicit mapping")
|
||||
|
||||
err = authorizer.Authorize("user2", "root")
|
||||
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
err = authorizer.Authorize("user2", "postgres")
|
||||
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
|
||||
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||
|
||||
// Unauthorized user should still be denied
|
||||
err = authorizer.Authorize("user3", "root")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||
}
|
||||
@@ -27,9 +27,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
// Configure SSH authorization for the test user
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"sub": user,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
@@ -23,10 +23,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
tc.setupServer(server)
|
||||
}
|
||||
|
||||
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
|
||||
testUserHash, err := sshuserhash.HashUserID("test-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get current OS username for machine user mapping
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
currentUser: {0}, // Allow test-user (index 0) to access current OS user
|
||||
},
|
||||
}
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -138,6 +139,8 @@ type Server struct {
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
|
||||
authorizer *sshauth.Authorizer
|
||||
|
||||
suSupportsPty bool
|
||||
loginIsUtilLinux bool
|
||||
}
|
||||
@@ -179,6 +182,7 @@ func New(config *Config) *Server {
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// UpdateSSHAuth updates the SSH fine-grained access control configuration
|
||||
// This should be called when network map updates include new SSH auth configuration
|
||||
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Reset JWT validator/extractor to pick up new userIDClaim
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
|
||||
s.authorizer.Update(config)
|
||||
}
|
||||
|
||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
s.mu.RLock()
|
||||
@@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error {
|
||||
return nil
|
||||
}
|
||||
config := s.jwtConfig
|
||||
authorizer := s.authorizer
|
||||
s.mu.RUnlock()
|
||||
|
||||
if config == nil {
|
||||
@@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error {
|
||||
true,
|
||||
)
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(
|
||||
// Use custom userIDClaim from authorizer if available
|
||||
extractorOptions := []jwt.ClaimsExtractorOption{
|
||||
jwt.WithAudience(config.Audience),
|
||||
)
|
||||
}
|
||||
if authorizer.GetUserIDClaim() != "" {
|
||||
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
||||
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
|
||||
}
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(extractorOptions...)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
|
||||
}
|
||||
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
osUsername := ctx.User()
|
||||
remoteAddr := ctx.RemoteAddr()
|
||||
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
|
||||
s.mu.RLock()
|
||||
authorizer := s.authorizer
|
||||
s.mu.RUnlock()
|
||||
|
||||
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
|
||||
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
|
||||
return false
|
||||
}
|
||||
|
||||
key := newAuthKey(osUsername, remoteAddr)
|
||||
s.mu.Lock()
|
||||
s.pendingAuthJWT[key] = userAuth.UserId
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
@@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
@@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
||||
if err != nil {
|
||||
@@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
|
||||
@@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
}
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,10 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
@@ -16,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
@@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
|
||||
sshConfig := &proto.SSHConfig{
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
SshEnabled: peer.SSHEnabled || enableSSH,
|
||||
}
|
||||
|
||||
if peer.SSHEnabled {
|
||||
if sshConfig.SshEnabled {
|
||||
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
||||
}
|
||||
|
||||
@@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
|
||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
@@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
|
||||
@@ -635,7 +635,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
|
||||
@@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
|
||||
@@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
@@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
PortRanges: []types.RulePortRange{portRange},
|
||||
}},
|
||||
}
|
||||
if protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||
policy.Rules[0].AuthorizedUser = userAuth.UserId
|
||||
}
|
||||
|
||||
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
||||
if err != nil {
|
||||
@@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
LocalFlags: &api.PeerLocalFlags{
|
||||
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
if !approved {
|
||||
@@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
if osVersion == "" {
|
||||
osVersion = peer.Meta.Core
|
||||
}
|
||||
|
||||
return &api.PeerBatch{
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
@@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
LocalFlags: &api.PeerLocalFlags{
|
||||
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
pr.Protocol = types.PolicyRuleProtocolUDP
|
||||
case api.PolicyRuleUpdateProtocolIcmp:
|
||||
pr.Protocol = types.PolicyRuleProtocolICMP
|
||||
case api.PolicyRuleUpdateProtocolNetbirdSsh:
|
||||
pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
||||
return
|
||||
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
}
|
||||
}
|
||||
|
||||
if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
|
||||
for _, sourceGroupID := range pr.Sources {
|
||||
_, ok := (*rule.AuthorizedGroups)[sourceGroupID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
pr.AuthorizedGroups = *rule.AuthorizedGroups
|
||||
}
|
||||
|
||||
// validate policy object
|
||||
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
|
||||
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
||||
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
||||
DestinationResource: r.DestinationResource.ToAPIResponse(),
|
||||
}
|
||||
|
||||
if len(r.AuthorizedGroups) != 0 {
|
||||
authorizedGroupsCopy := r.AuthorizedGroups
|
||||
rule.AuthorizedGroups = &authorizedGroupsCopy
|
||||
}
|
||||
|
||||
if len(r.Ports) != 0 {
|
||||
portsCopy := r.Ports
|
||||
rule.Ports = &portsCopy
|
||||
|
||||
@@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
|
||||
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
|
||||
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peer.ID {
|
||||
return peer, nil
|
||||
|
||||
@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 8)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 1)
|
||||
assert.Contains(t, peers, account.Peers["peerI"])
|
||||
|
||||
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 7)
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
|
||||
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
@@ -1910,16 +1910,16 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
||||
if len(policyIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
|
||||
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
|
||||
rows, err := s.pool.Query(ctx, query, policyIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
||||
var r types.PolicyRule
|
||||
var dest, destRes, sources, sourceRes, ports, portRanges []byte
|
||||
var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
|
||||
var enabled, bidirectional sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
|
||||
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &r.AuthorizedUser)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
@@ -1945,6 +1945,9 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
||||
if portRanges != nil {
|
||||
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
||||
}
|
||||
if authorizedGroups != nil {
|
||||
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
|
||||
}
|
||||
}
|
||||
return &r, err
|
||||
})
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -45,8 +46,10 @@ const (
|
||||
|
||||
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
||||
nativeSSHPortString = "22022"
|
||||
nativeSSHPortNumber = 22022
|
||||
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
||||
defaultSSHPortString = "22"
|
||||
defaultSSHPortNumber = 22
|
||||
)
|
||||
|
||||
type supportedFeatures struct {
|
||||
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
peer := a.Peers[peerID]
|
||||
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
EnableSSH: enableSSH,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
@@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
|
||||
sshEnabled := false
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
@@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
|
||||
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
sshEnabled = true
|
||||
switch {
|
||||
case len(rule.AuthorizedGroups) > 0:
|
||||
for groupID, localUsers := range rule.AuthorizedGroups {
|
||||
userIDs, ok := groupIDToUserIDs[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(localUsers) == 0 {
|
||||
localUsers = []string{auth.Wildcard}
|
||||
}
|
||||
|
||||
for _, localUser := range localUsers {
|
||||
if authorizedUsers[localUser] == nil {
|
||||
authorizedUsers[localUser] = make(map[string]struct{})
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
authorizedUsers[localUser][userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
case rule.AuthorizedUser != "":
|
||||
if authorizedUsers[auth.Wildcard] == nil {
|
||||
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
|
||||
}
|
||||
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
peers, fwRules := getAccumulatedResources()
|
||||
return peers, fwRules, authorizedUsers, sshEnabled
|
||||
}
|
||||
|
||||
func (a *Account) getAllowedUserIDs() map[string]struct{} {
|
||||
users := make(map[string]struct{})
|
||||
for _, nbUser := range a.Users {
|
||||
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
|
||||
users[nbUser.Id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
@@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
protocol := rule.Protocol
|
||||
if protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
protocol = PolicyRuleProtocolTCP
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
Protocol: string(protocol),
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
@@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
}
|
||||
}
|
||||
|
||||
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||
for _, pr := range portRanges {
|
||||
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portsIncludesSSH(ports []string) bool {
|
||||
for _, port := range ports {
|
||||
if port == defaultSSHPortString || port == nativeSSHPortString {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
@@ -1660,6 +1743,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) GetActiveGroupUsers() map[string][]string {
|
||||
allGroupID := ""
|
||||
group, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get group all: %v", err)
|
||||
} else {
|
||||
allGroupID = group.ID
|
||||
}
|
||||
groups := make(map[string][]string, len(a.GroupsG))
|
||||
for _, user := range a.Users {
|
||||
if !user.IsBlocked() && !user.IsServiceUser {
|
||||
for _, groupID := range user.AutoGroups {
|
||||
groups[groupID] = append(groups[groupID], user.Id)
|
||||
}
|
||||
groups[allGroupID] = append(groups[allGroupID], user.Id)
|
||||
}
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||
@@ -1691,7 +1794,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
expanded = addNativeSSHRule(base, expanded)
|
||||
}
|
||||
|
||||
|
||||
@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetActiveGroupUsers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected map[string][]string
|
||||
}{
|
||||
{
|
||||
name: "all users are active",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user3"},
|
||||
"group2": {"user1", "user2"},
|
||||
"group3": {"user2"},
|
||||
"": {"user1", "user2", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some users are blocked",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2", "group3"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user3"},
|
||||
"group2": {"user1"},
|
||||
"group3": {"user3"},
|
||||
"": {"user1", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all users are blocked",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2"},
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{},
|
||||
},
|
||||
{
|
||||
name: "user with no auto groups",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user2"},
|
||||
"": {"user1", "user2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty account",
|
||||
account: &Account{
|
||||
Users: map[string]*User{},
|
||||
},
|
||||
expected: map[string][]string{},
|
||||
},
|
||||
{
|
||||
name: "multiple users in same group",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user2", "user3"},
|
||||
"": {"user1", "user2", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user in multiple groups with blocked users",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1"},
|
||||
"group2": {"user1"},
|
||||
"group3": {"user1", "user3"},
|
||||
"": {"user1", "user3"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetActiveGroupUsers()
|
||||
|
||||
// Check that the number of groups matches
|
||||
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
|
||||
|
||||
// Check each group's users
|
||||
for groupID, expectedUsers := range tt.expected {
|
||||
actualUsers, exists := result[groupID]
|
||||
assert.True(t, exists, "group %s should exist in result", groupID)
|
||||
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
|
||||
}
|
||||
|
||||
// Ensure no extra groups in result
|
||||
for groupID := range result {
|
||||
_, exists := tt.expected[groupID]
|
||||
assert.True(t, exists, "unexpected group %s in result", groupID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -38,6 +38,8 @@ type NetworkMap struct {
|
||||
FirewallRules []*FirewallRule
|
||||
RoutesFirewallRules []*RouteFirewallRule
|
||||
ForwardingRules []*ForwardingRule
|
||||
AuthorizedUsers map[string]map[string]struct{}
|
||||
EnableSSH bool
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||
|
||||
@@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
b.Run("old builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
@@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
b.Run("old builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -23,6 +23,8 @@ const (
|
||||
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
||||
// PolicyRuleProtocolICMP type of traffic
|
||||
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
||||
// PolicyRuleProtocolNetbirdSSH type of traffic
|
||||
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
|
||||
protocol = PolicyRuleProtocolUDP
|
||||
case "icmp":
|
||||
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
|
||||
case "netbird-ssh":
|
||||
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
|
||||
default:
|
||||
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
|
||||
}
|
||||
|
||||
@@ -80,6 +80,12 @@ type PolicyRule struct {
|
||||
|
||||
// PortRanges a list of port ranges.
|
||||
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||
|
||||
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
|
||||
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
|
||||
|
||||
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
|
||||
AuthorizedUser string
|
||||
}
|
||||
|
||||
// Copy returns a copy of a policy rule
|
||||
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
Protocol: pm.Protocol,
|
||||
Ports: make([]string, len(pm.Ports)),
|
||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
|
||||
AuthorizedUser: pm.AuthorizedUser,
|
||||
}
|
||||
copy(rule.Destinations, pm.Destinations)
|
||||
copy(rule.Sources, pm.Sources)
|
||||
copy(rule.Ports, pm.Ports)
|
||||
copy(rule.PortRanges, pm.PortRanges)
|
||||
for k, v := range pm.AuthorizedGroups {
|
||||
rule.AuthorizedGroups[k] = make([]string, len(v))
|
||||
copy(rule.AuthorizedGroups[k], v)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
@@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||
_, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
||||
}
|
||||
|
||||
if userHadPeers {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
updateAccountPeers = true
|
||||
|
||||
err = transaction.SaveUser(ctx, updatedUser)
|
||||
if err != nil {
|
||||
@@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled && updateAccountPeers {
|
||||
if updateAccountPeers {
|
||||
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
@@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a new regular user should not update account peers and not send peer update
|
||||
// Creating a new regular user should send peer update (as users are not filtered yet)
|
||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// updating user with no linked peers should not update account peers and not send peer update
|
||||
// updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
|
||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -488,6 +488,8 @@ components:
|
||||
description: Indicates whether the peer is ephemeral or not
|
||||
type: boolean
|
||||
example: false
|
||||
local_flags:
|
||||
$ref: '#/components/schemas/PeerLocalFlags'
|
||||
required:
|
||||
- city_name
|
||||
- connected
|
||||
@@ -514,6 +516,49 @@ components:
|
||||
- serial_number
|
||||
- extra_dns_labels
|
||||
- ephemeral
|
||||
PeerLocalFlags:
|
||||
type: object
|
||||
properties:
|
||||
rosenpass_enabled:
|
||||
description: Indicates whether Rosenpass is enabled on this peer
|
||||
type: boolean
|
||||
example: true
|
||||
rosenpass_permissive:
|
||||
description: Indicates whether Rosenpass is in permissive mode or not
|
||||
type: boolean
|
||||
example: false
|
||||
server_ssh_allowed:
|
||||
description: Indicates whether SSH access this peer is allowed or not
|
||||
type: boolean
|
||||
example: true
|
||||
disable_client_routes:
|
||||
description: Indicates whether client routes are disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
disable_server_routes:
|
||||
description: Indicates whether server routes are disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
disable_dns:
|
||||
description: Indicates whether DNS management is disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
disable_firewall:
|
||||
description: Indicates whether firewall management is disabled on this peer or not
|
||||
type: boolean
|
||||
example: false
|
||||
block_lan_access:
|
||||
description: Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||
type: boolean
|
||||
example: false
|
||||
block_inbound:
|
||||
description: Indicates whether inbound traffic is blocked on this peer
|
||||
type: boolean
|
||||
example: false
|
||||
lazy_connection_enabled:
|
||||
description: Indicates whether lazy connection is enabled on this peer
|
||||
type: boolean
|
||||
example: false
|
||||
PeerTemporaryAccessRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -936,7 +981,7 @@ components:
|
||||
protocol:
|
||||
description: Policy rule type of the traffic
|
||||
type: string
|
||||
enum: ["all", "tcp", "udp", "icmp"]
|
||||
enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
|
||||
example: "tcp"
|
||||
ports:
|
||||
description: Policy rule affected ports
|
||||
@@ -949,6 +994,14 @@ components:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/RulePortRange'
|
||||
authorized_groups:
|
||||
description: Map of user group ids to a list of local users
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: "group1"
|
||||
required:
|
||||
- name
|
||||
- enabled
|
||||
|
||||
@@ -130,10 +130,11 @@ const (
|
||||
|
||||
// Defines values for PolicyRuleProtocol.
|
||||
const (
|
||||
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
||||
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
||||
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
||||
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
||||
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
||||
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
||||
PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
|
||||
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
||||
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
||||
)
|
||||
|
||||
// Defines values for PolicyRuleMinimumAction.
|
||||
@@ -144,10 +145,11 @@ const (
|
||||
|
||||
// Defines values for PolicyRuleMinimumProtocol.
|
||||
const (
|
||||
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
||||
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
||||
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
||||
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
||||
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
||||
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
||||
PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
|
||||
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
||||
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
||||
)
|
||||
|
||||
// Defines values for PolicyRuleUpdateAction.
|
||||
@@ -158,10 +160,11 @@ const (
|
||||
|
||||
// Defines values for PolicyRuleUpdateProtocol.
|
||||
const (
|
||||
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
||||
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
||||
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
||||
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
||||
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
||||
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
||||
PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
|
||||
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
||||
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
||||
)
|
||||
|
||||
// Defines values for ResourceType.
|
||||
@@ -1077,7 +1080,8 @@ type Peer struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
|
||||
// LastSeen Last time peer connected to Netbird's management service
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||
|
||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
@@ -1167,7 +1171,8 @@ type PeerBatch struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
|
||||
// LastSeen Last time peer connected to Netbird's management service
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||
|
||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
@@ -1197,6 +1202,39 @@ type PeerBatch struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// PeerLocalFlags defines model for PeerLocalFlags.
|
||||
type PeerLocalFlags struct {
|
||||
// BlockInbound Indicates whether inbound traffic is blocked on this peer
|
||||
BlockInbound *bool `json:"block_inbound,omitempty"`
|
||||
|
||||
// BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||
BlockLanAccess *bool `json:"block_lan_access,omitempty"`
|
||||
|
||||
// DisableClientRoutes Indicates whether client routes are disabled on this peer or not
|
||||
DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
|
||||
|
||||
// DisableDns Indicates whether DNS management is disabled on this peer or not
|
||||
DisableDns *bool `json:"disable_dns,omitempty"`
|
||||
|
||||
// DisableFirewall Indicates whether firewall management is disabled on this peer or not
|
||||
DisableFirewall *bool `json:"disable_firewall,omitempty"`
|
||||
|
||||
// DisableServerRoutes Indicates whether server routes are disabled on this peer or not
|
||||
DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
|
||||
|
||||
// LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
|
||||
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
||||
|
||||
// RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
|
||||
RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
|
||||
|
||||
// RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
|
||||
RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
|
||||
|
||||
// ServerSshAllowed Indicates whether SSH access this peer is allowed or not
|
||||
ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
|
||||
}
|
||||
|
||||
// PeerMinimum defines model for PeerMinimum.
|
||||
type PeerMinimum struct {
|
||||
// Id Peer ID
|
||||
@@ -1349,6 +1387,9 @@ type PolicyRule struct {
|
||||
// Action Policy rule accept or drops packets
|
||||
Action PolicyRuleAction `json:"action"`
|
||||
|
||||
// AuthorizedGroups Map of user group ids to a list of local users
|
||||
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||
|
||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||
Bidirectional bool `json:"bidirectional"`
|
||||
|
||||
@@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct {
|
||||
// Action Policy rule accept or drops packets
|
||||
Action PolicyRuleMinimumAction `json:"action"`
|
||||
|
||||
// AuthorizedGroups Map of user group ids to a list of local users
|
||||
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||
|
||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||
Bidirectional bool `json:"bidirectional"`
|
||||
|
||||
@@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct {
|
||||
// Action Policy rule accept or drops packets
|
||||
Action PolicyRuleUpdateAction `json:"action"`
|
||||
|
||||
// AuthorizedGroups Map of user group ids to a list of local users
|
||||
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||
|
||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||
Bidirectional bool `json:"bidirectional"`
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -332,6 +332,24 @@ message NetworkMap {
|
||||
bool routesFirewallRulesIsEmpty = 11;
|
||||
|
||||
repeated ForwardingRule forwardingRules = 12;
|
||||
|
||||
// SSHAuth represents SSH authorization configuration
|
||||
SSHAuth sshAuth = 13;
|
||||
}
|
||||
|
||||
message SSHAuth {
|
||||
// UserIDClaim is the JWT claim to be used to get the users ID
|
||||
string UserIDClaim = 1;
|
||||
|
||||
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
|
||||
repeated bytes AuthorizedUsers = 2;
|
||||
|
||||
// MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
|
||||
map<string, MachineUserIndexes> machine_users = 3;
|
||||
}
|
||||
|
||||
message MachineUserIndexes {
|
||||
repeated uint32 indexes = 1;
|
||||
}
|
||||
|
||||
// RemotePeerConfig represents a configuration of a remote peer.
|
||||
|
||||
28
shared/sshauth/userhash.go
Normal file
28
shared/sshauth/userhash.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package sshauth
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"golang.org/x/crypto/blake2b"
|
||||
)
|
||||
|
||||
// UserIDHash represents a hashed user ID (BLAKE2b-128)
|
||||
type UserIDHash [16]byte
|
||||
|
||||
// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
|
||||
// This function must produce the same hash on both client and management server
|
||||
func HashUserID(userID string) (UserIDHash, error) {
|
||||
hash, err := blake2b.New(16, nil)
|
||||
if err != nil {
|
||||
return UserIDHash{}, err
|
||||
}
|
||||
hash.Write([]byte(userID))
|
||||
var result UserIDHash
|
||||
copy(result[:], hash.Sum(nil))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// String returns the hexadecimal string representation of the hash
|
||||
func (h UserIDHash) String() string {
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
210
shared/sshauth/userhash_test.go
Normal file
210
shared/sshauth/userhash_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package sshauth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashUserID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
}{
|
||||
{
|
||||
name: "simple user ID",
|
||||
userID: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "UUID format",
|
||||
userID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
},
|
||||
{
|
||||
name: "numeric ID",
|
||||
userID: "12345",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
userID: "",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
userID: "user+test@domain.com",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
userID: "用户@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := HashUserID(tt.userID)
|
||||
if err != nil {
|
||||
t.Errorf("HashUserID() error = %v, want nil", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify hash is non-zero for non-empty inputs
|
||||
if tt.userID != "" && hash == [16]byte{} {
|
||||
t.Errorf("HashUserID() returned zero hash for non-empty input")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashUserID_Consistency(t *testing.T) {
|
||||
userID := "test@example.com"
|
||||
|
||||
hash1, err1 := HashUserID(userID)
|
||||
if err1 != nil {
|
||||
t.Fatalf("First HashUserID() error = %v", err1)
|
||||
}
|
||||
|
||||
hash2, err2 := HashUserID(userID)
|
||||
if err2 != nil {
|
||||
t.Fatalf("Second HashUserID() error = %v", err2)
|
||||
}
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashUserID_Uniqueness(t *testing.T) {
|
||||
tests := []struct {
|
||||
userID1 string
|
||||
userID2 string
|
||||
}{
|
||||
{"user1@example.com", "user2@example.com"},
|
||||
{"alice@domain.com", "bob@domain.com"},
|
||||
{"test", "test1"},
|
||||
{"", "a"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
hash1, err1 := HashUserID(tt.userID1)
|
||||
if err1 != nil {
|
||||
t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
|
||||
}
|
||||
|
||||
hash2, err2 := HashUserID(tt.userID2)
|
||||
if err2 != nil {
|
||||
t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIDHash_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hash UserIDHash
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "zero hash",
|
||||
hash: [16]byte{},
|
||||
expected: "00000000000000000000000000000000",
|
||||
},
|
||||
{
|
||||
name: "small value",
|
||||
hash: [16]byte{15: 0xff},
|
||||
expected: "000000000000000000000000000000ff",
|
||||
},
|
||||
{
|
||||
name: "large value",
|
||||
hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
|
||||
expected: "0000000000000000deadbeefcafebabe",
|
||||
},
|
||||
{
|
||||
name: "max value",
|
||||
hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
expected: "ffffffffffffffffffffffffffffffff",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.hash.String()
|
||||
if result != tt.expected {
|
||||
t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIDHash_String_Length(t *testing.T) {
|
||||
// Test that String() always returns 32 hex characters (16 bytes * 2)
|
||||
userID := "test@example.com"
|
||||
hash, err := HashUserID(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("HashUserID() error = %v", err)
|
||||
}
|
||||
|
||||
result := hash.String()
|
||||
if len(result) != 32 {
|
||||
t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
|
||||
}
|
||||
|
||||
// Verify it's valid hex
|
||||
for i, c := range result {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashUserID_KnownValues(t *testing.T) {
|
||||
// Test with known BLAKE2b-128 values to ensure correct implementation
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
expected UserIDHash
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
userID: "",
|
||||
// BLAKE2b-128 of empty string
|
||||
expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
|
||||
},
|
||||
{
|
||||
name: "single character 'a'",
|
||||
userID: "a",
|
||||
// BLAKE2b-128 of "a"
|
||||
expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := HashUserID(tt.userID)
|
||||
if err != nil {
|
||||
t.Errorf("HashUserID() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if hash != tt.expected {
|
||||
t.Errorf("HashUserID(%q) = %x, want %x",
|
||||
tt.userID, hash, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHashUserID(b *testing.B) {
|
||||
userID := "user@example.com"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = HashUserID(userID)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUserIDHash_String(b *testing.B) {
|
||||
hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = hash.String()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user