[client, management] Feature/ssh fine grained access (#4969)

Add fine-grained SSH access control with authorized users/groups
This commit is contained in:
Zoltan Papp
2025-12-29 12:50:41 +01:00
committed by GitHub
parent 73201c4f3e
commit 67f7b2404e
32 changed files with 2345 additions and 512 deletions

View File

@@ -1151,6 +1151,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store

View File

@@ -11,15 +11,18 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
UpdateSSHAuth(config *sshauth.Config)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
return sshServer.GetStatus()
}
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil {
return
}
if e.sshServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
// Update SSH server with new authorization configuration
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.sshServer.UpdateSSHAuth(authConfig)
}

184
client/ssh/auth/auth.go Normal file
View File

@@ -0,0 +1,184 @@
package auth
import (
"errors"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const (
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
)
var (
ErrEmptyUserID = errors.New("JWT user ID is empty")
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
)
// Authorizer handles SSH fine-grained access control authorization
type Authorizer struct {
// UserIDClaim is the JWT claim to extract the user ID from
userIDClaim string
// authorizedUsers is a list of hashed user IDs authorized to access this peer
authorizedUsers []sshuserhash.UserIDHash
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// mu protects the list of users
mu sync.RWMutex
}
// Config contains configuration for the SSH authorizer
type Config struct {
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
UserIDClaim string
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
AuthorizedUsers []sshuserhash.UserIDHash
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to login as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
}
// Update updates the authorizer configuration with new values
func (a *Authorizer) Update(config *Config) {
a.mu.Lock()
defer a.mu.Unlock()
if config == nil {
// Clear authorization
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
log.Info("SSH authorization cleared")
return
}
userIDClaim := config.UserIDClaim
if userIDClaim == "" {
userIDClaim = DefaultUserIDClaim
}
a.userIDClaim = userIDClaim
// Store authorized users list
a.authorizedUsers = config.AuthorizedUsers
// Store machine users mapping
machineUsers := make(map[string][]uint32)
for osUser, indexes := range config.MachineUsers {
if len(indexes) > 0 {
machineUsers[osUser] = indexes
}
}
a.machineUsers = machineUsers
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates if a user is authorized to login as the specified OS user
// Returns nil if authorized, or an error describing why authorization failed
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
if jwtUserID == "" {
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
return ErrEmptyUserID
}
// Hash the JWT user ID for comparison
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
if err != nil {
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
return fmt.Errorf("failed to hash user ID: %w", err)
}
a.mu.RLock()
defer a.mu.RUnlock()
// Find the index of this user in the authorized list
userIndex, found := a.findUserIndex(hashedUserID)
if !found {
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
return ErrUserNotAuthorized
}
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
}
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
// Checks wildcard mapping first, then specific OS user mappings
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
}
// Check for specific OS username mapping
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
if !hasMachineUserMapping {
// No mapping for this OS user - deny by default (fail closed)
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
return ErrNoMachineUserMapping
}
// Check if user's index is in the allowed indexes for this specific OS user
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
return ErrUserNotMappedToOSUser
}
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
// GetUserIDClaim returns the JWT claim name used to extract user IDs
func (a *Authorizer) GetUserIDClaim() string {
a.mu.RLock()
defer a.mu.RUnlock()
return a.userIDClaim
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
for i, id := range a.authorizedUsers {
if id == hashedUserID {
return i, true
}
}
return 0, false
}
// isIndexInList checks if an index exists in a list of indexes
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
for _, idx := range indexes {
if idx == index {
return true
}
}
return false
}

View File

@@ -0,0 +1,612 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/sshauth"
)
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list with one user
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Try to authorize a different user
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
}
authorizer.Update(config)
// All attempts should fail when no machine user mappings exist (fail closed)
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should access root and admin
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "admin")
assert.NoError(t, err)
// user2 (index 1) should access root and postgres
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user3 (index 2) should access postgres
err = authorizer.Authorize("user3", "postgres")
assert.NoError(t, err)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should NOT access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 (index 1) should NOT access admin
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access root
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access admin
err = authorizer.Authorize("user3", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
}
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only root is mapped
},
}
authorizer.Update(config)
// user1 should NOT access an unmapped OS user (fail closed)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Empty user ID should fail
err = authorizer.Authorize("", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrEmptyUserID)
}
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up multiple authorized users
userHashes := make([]sshauth.UserIDHash, 10)
for i := 0; i < 10; i++ {
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
require.NoError(t, err)
userHashes[i] = hash
}
// Create machine user mapping for all users
rootIndexes := make([]uint32, 10)
for i := 0; i < 10; i++ {
rootIndexes[i] = uint32(i)
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": rootIndexes,
},
}
authorizer.Update(config)
// All users should be authorized for root
for i := 0; i < 10; i++ {
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
assert.NoError(t, err, "user%d should be authorized", i)
}
// User not in list should fail
err := authorizer.Authorize("unknown-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
authorizer := NewAuthorizer()
// Set up initial configuration
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{"root": {0}},
}
authorizer.Update(config)
// user1 should be authorized
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Clear configuration
authorizer.Update(nil)
// user1 should no longer be authorized
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Machine users with empty index lists should be filtered out
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
"postgres": {}, // empty list - should be filtered out
"admin": nil, // nil list - should be filtered out
},
}
authorizer.Update(config)
// root should work
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// postgres should fail (no mapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// admin should fail (no mapping)
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Set up with custom user ID claim
user1Hash, err := sshauth.HashUserID("user@example.com")
require.NoError(t, err)
config := &Config{
UserIDClaim: "email",
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
},
}
authorizer.Update(config)
// Verify the custom claim is set
assert.Equal(t, "email", authorizer.GetUserIDClaim())
// Authorize with email as user ID
err = authorizer.Authorize("user@example.com", "root")
assert.NoError(t, err)
}
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Verify default claim
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
// Set up with empty user ID claim (should use default)
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: "", // empty - should use default
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Should fall back to default
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
}
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
authorizer := NewAuthorizer()
// Create a large authorized users list
const numUsers = 1000
userHashes := make([]sshauth.UserIDHash, numUsers)
for i := 0; i < numUsers; i++ {
hash, err := sshauth.HashUserID("user" + string(rune(i)))
require.NoError(t, err)
userHashes[i] = hash
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": {0, 500, 999}, // first, middle, and last user
},
}
authorizer.Update(config)
// First user should have access
err := authorizer.Authorize("user"+string(rune(0)), "root")
assert.NoError(t, err)
// Middle user should have access
err = authorizer.Authorize("user"+string(rune(500)), "root")
assert.NoError(t, err)
// Last user should have access
err = authorizer.Authorize("user"+string(rune(999)), "root")
assert.NoError(t, err)
// User not in mapping should NOT have access
err = authorizer.Authorize("user"+string(rune(100)), "root")
assert.Error(t, err)
}
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1},
},
}
authorizer.Update(config)
// Test concurrent authorization calls (should be safe to read concurrently)
const numGoroutines = 100
errChan := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(idx int) {
user := "user1"
if idx%2 == 0 {
user = "user2"
}
err := authorizer.Authorize(user, "root")
errChan <- err
}(i)
}
// Wait for all goroutines to complete and collect errors
for i := 0; i < numGoroutines; i++ {
err := <-errChan
assert.NoError(t, err)
}
}
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
// Configure with wildcard - all authorized users can access any OS user
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1, 2}, // wildcard with all user indexes
},
}
authorizer.Update(config)
// All authorized users should be able to access any OS user
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "admin")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "ubuntu")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "nginx")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "docker")
assert.NoError(t, err)
}
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Configure with wildcard
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"*": {0},
},
}
authorizer.Update(config)
// user1 should have access
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Unauthorized user should still be denied even with wildcard
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with both wildcard and specific mappings
// Wildcard takes precedence for users in the wildcard index list
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1}, // wildcard for both users
"root": {0}, // specific mapping that would normally restrict to user1 only
},
}
authorizer.Update(config)
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
// Both users should be able to access any other OS user via wildcard
err = authorizer.Authorize("user1", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "admin")
assert.NoError(t, err)
}
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure WITHOUT wildcard - only specific mappings
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only user1
"postgres": {1}, // only user2
},
}
authorizer.Update(config)
// user1 can access root
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// user2 can access postgres
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user1 cannot access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 cannot access root
err = authorizer.Authorize("user2", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// Neither can access unmapped OS users
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
// This test covers the scenario where wildcard exists with limited indexes.
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
// Other users can only access OS users they are explicitly mapped to.
authorizer := NewAuthorizer()
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
wasmHash, err := sshauth.HashUserID("wasm")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with wildcard having only index 0, and specific mappings for other OS users
config := &Config{
UserIDClaim: "sub",
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
"alice": {1}, // specific mapping for user2
"bob": {1}, // specific mapping for user2
},
}
authorizer.Update(config)
// wasm (index 0) should access any OS user via wildcard
err = authorizer.Authorize("wasm", "root")
assert.NoError(t, err, "wasm should access root via wildcard")
err = authorizer.Authorize("wasm", "alice")
assert.NoError(t, err, "wasm should access alice via wildcard")
err = authorizer.Authorize("wasm", "bob")
assert.NoError(t, err, "wasm should access bob via wildcard")
err = authorizer.Authorize("wasm", "postgres")
assert.NoError(t, err, "wasm should access postgres via wildcard")
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
err = authorizer.Authorize("user2", "alice")
assert.NoError(t, err, "user2 should access alice via explicit mapping")
err = authorizer.Authorize("user2", "bob")
assert.NoError(t, err, "user2 should access bob via explicit mapping")
err = authorizer.Authorize("user2", "root")
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "postgres")
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// Unauthorized user should still be denied
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}

View File

@@ -27,9 +27,11 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}

View File

@@ -23,10 +23,12 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestJWTEnforcement(t *testing.T) {
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
tc.setupServer(server)
}
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
// Get current OS username for machine user mapping
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0}, // Allow test-user (index 0) to access current OS user
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())

View File

@@ -21,6 +21,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
@@ -138,6 +139,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
authorizer *sshauth.Authorizer
suSupportsPty bool
loginIsUtilLinux bool
}
@@ -179,6 +182,7 @@ func New(config *Config) *Server {
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
}
return s
@@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
// Reset JWT validator/extractor to pick up new userIDClaim
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
}
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
@@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error {
return nil
}
config := s.jwtConfig
authorizer := s.authorizer
s.mu.RUnlock()
if config == nil {
@@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error {
true,
)
extractor := jwt.NewClaimsExtractor(
// Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience),
)
}
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
}
extractor := jwt.NewClaimsExtractor(extractorOptions...)
s.mu.Lock()
defer s.mu.Unlock()
@@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
osUsername := ctx.User()
remoteAddr := ctx.RemoteAddr()
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false
}
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
s.mu.RLock()
authorizer := s.authorizer
s.mu.RUnlock()
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
return false
}
key := newAuthKey(osUsername, remoteAddr)
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
return true
}

View File

@@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
@@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
postureChecks, err := c.getPeerPostureChecks(account, peerId)
if err != nil {
@@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
}
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}

View File

@@ -6,7 +6,10 @@ import (
"net/url"
"strings"
log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
@@ -16,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/sshauth"
)
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
@@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
return nbConfig
}
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
sshConfig := &proto.SSHConfig{
SshEnabled: peer.SSHEnabled,
SshEnabled: peer.SSHEnabled || enableSSH,
}
if peer.SSHEnabled {
if sshConfig.SshEnabled {
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
}
@@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
},
Checks: toProtocolChecks(ctx, checks),
}
@@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.ForwardingRules = forwardingRules
}
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
return response
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{

View File

@@ -635,7 +635,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
Checks: toProtocolChecks(ctx, postureChecks),
}

View File

@@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
}
if settings.GroupsPropagationEnabled {
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
if err != nil {
return err
}
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
if err != nil {
return err
}
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
if err != nil {
return err
}
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
return nil

View File

@@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}

View File

@@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
@@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
PortRanges: []types.RulePortRange{portRange},
}},
}
if protocol == types.PolicyRuleProtocolNetbirdSSH {
policy.Rules[0].AuthorizedUser = userAuth.UserId
}
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
if err != nil {
@@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
LocalFlags: &api.PeerLocalFlags{
BlockInbound: &peer.Meta.Flags.BlockInbound,
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
DisableDns: &peer.Meta.Flags.DisableDNS,
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
},
}
if !approved {
@@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
if osVersion == "" {
osVersion = peer.Meta.Core
}
return &api.PeerBatch{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
@@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
LocalFlags: &api.PeerLocalFlags{
BlockInbound: &peer.Meta.Flags.BlockInbound,
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
DisableDns: &peer.Meta.Flags.DisableDNS,
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
},
}
}

View File

@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
pr.Protocol = types.PolicyRuleProtocolUDP
case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = types.PolicyRuleProtocolICMP
case api.PolicyRuleUpdateProtocolNetbirdSsh:
pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
}
}
if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
for _, sourceGroupID := range pr.Sources {
_, ok := (*rule.AuthorizedGroups)[sourceGroupID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
return
}
}
pr.AuthorizedGroups = *rule.AuthorizedGroups
}
// validate policy object
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
DestinationResource: r.DestinationResource.ToAPIResponse(),
}
if len(r.AuthorizedGroups) != 0 {
authorizedGroupsCopy := r.AuthorizedGroups
rule.AuthorizedGroups = &authorizedGroupsCopy
}
if len(r.Ports) != 0 {
portsCopy := r.Ports
rule.Ports = &portsCopy

View File

@@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
}
for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID {
return peer, nil

View File

@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
})
t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -1910,16 +1910,16 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
if len(policyIDs) == 0 {
return nil, nil
}
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, policyIDs)
if err != nil {
return nil, err
}
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
var r types.PolicyRule
var dest, destRes, sources, sourceRes, ports, portRanges []byte
var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
var enabled, bidirectional sql.NullBool
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &r.AuthorizedUser)
if err == nil {
if enabled.Valid {
r.Enabled = enabled.Bool
@@ -1945,6 +1945,9 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
if portRanges != nil {
_ = json.Unmarshal(portRanges, &r.PortRanges)
}
if authorizedGroups != nil {
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
}
}
return &r, err
})

View File

@@ -16,6 +16,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -45,8 +46,10 @@ const (
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
nativeSSHPortString = "22022"
nativeSSHPortNumber = 22022
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
defaultSSHPortString = "22"
defaultSSHPortNumber = 22
)
type supportedFeatures struct {
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
}
}
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
AuthorizedUsers: authorizedUsers,
EnableSSH: enableSSH,
}
if metrics != nil {
@@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
sshEnabled := false
for _, policy := range a.Policies {
if !policy.Enabled {
@@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID, localUsers := range rule.AuthorizedGroups {
userIDs, ok := groupIDToUserIDs[groupID]
if !ok {
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
continue
}
if len(localUsers) == 0 {
localUsers = []string{auth.Wildcard}
}
for _, localUser := range localUsers {
if authorizedUsers[localUser] == nil {
authorizedUsers[localUser] = make(map[string]struct{})
}
for _, userID := range userIDs {
authorizedUsers[localUser][userID] = struct{}{}
}
}
}
case rule.AuthorizedUser != "":
if authorizedUsers[auth.Wildcard] == nil {
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
}
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
default:
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
}
}
return getAccumulatedResources()
peers, fwRules := getAccumulatedResources()
return peers, fwRules, authorizedUsers, sshEnabled
}
func (a *Account) getAllowedUserIDs() map[string]struct{} {
users := make(map[string]struct{})
for _, nbUser := range a.Users {
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
users[nbUser.Id] = struct{}{}
}
}
return users
}
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
@@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
peersExists[peer.ID] = struct{}{}
}
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
Protocol: string(protocol),
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
@@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
}
}
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
}
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
for _, pr := range portRanges {
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
return true
}
}
return false
}
func portsIncludesSSH(ports []string) bool {
for _, port := range ports {
if port == defaultSSHPortString || port == nativeSSHPortString {
return true
}
}
return false
}
// getAllPeersFromGroups for given peer ID and list of groups
//
// Returns a list of peers from specified groups that pass specified posture checks
@@ -1660,6 +1743,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
return nil
}
func (a *Account) GetActiveGroupUsers() map[string][]string {
allGroupID := ""
group, err := a.GetGroupAll()
if err != nil {
log.Errorf("failed to get group all: %v", err)
} else {
allGroupID = group.ID
}
groups := make(map[string][]string, len(a.GroupsG))
for _, user := range a.Users {
if !user.IsBlocked() && !user.IsServiceUser {
for _, groupID := range user.AutoGroups {
groups[groupID] = append(groups[groupID], user.Id)
}
groups[allGroupID] = append(groups[allGroupID], user.Id)
}
}
return groups
}
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
@@ -1691,7 +1794,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
expanded = append(expanded, &fr)
}
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
expanded = addNativeSSHRule(base, expanded)
}

View File

@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
}
}
func Test_GetActiveGroupUsers(t *testing.T) {
tests := []struct {
name string
account *Account
expected map[string][]string
}{
{
name: "all users are active",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2", "group3"},
Blocked: false,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user3"},
"group2": {"user1", "user2"},
"group3": {"user2"},
"": {"user1", "user2", "user3"},
},
},
{
name: "some users are blocked",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2", "group3"},
Blocked: true,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1", "group3"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user3"},
"group2": {"user1"},
"group3": {"user3"},
"": {"user1", "user3"},
},
},
{
name: "all users are blocked",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1"},
Blocked: true,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2"},
Blocked: true,
},
},
},
expected: map[string][]string{},
},
{
name: "user with no auto groups",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user2"},
"": {"user1", "user2"},
},
},
{
name: "empty account",
account: &Account{
Users: map[string]*User{},
},
expected: map[string][]string{},
},
{
name: "multiple users in same group",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1"},
Blocked: false,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user2", "user3"},
"": {"user1", "user2", "user3"},
},
},
{
name: "user in multiple groups with blocked users",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2", "group3"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1", "group2"},
Blocked: true,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group3"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1"},
"group2": {"user1"},
"group3": {"user1", "user3"},
"": {"user1", "user3"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetActiveGroupUsers()
// Check that the number of groups matches
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
// Check each group's users
for groupID, expectedUsers := range tt.expected {
actualUsers, exists := result[groupID]
assert.True(t, exists, "group %s should exist in result", groupID)
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
}
// Ensure no extra groups in result
for groupID := range result {
_, exists := tt.expected[groupID]
assert.True(t, exists, "unexpected group %s in result", groupID)
}
})
}
}
func Test_FilterZoneRecordsForPeers(t *testing.T) {
tests := []struct {
name string

View File

@@ -38,6 +38,8 @@ type NetworkMap struct {
FirewallRules []*FirewallRule
RoutesFirewallRules []*RouteFirewallRule
ForwardingRules []*ForwardingRule
AuthorizedUsers map[string]map[string]struct{}
EnableSSH bool
}
func (nm *NetworkMap) Merge(other *NetworkMap) {

View File

@@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap)
@@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})

View File

@@ -23,6 +23,8 @@ const (
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
// PolicyRuleProtocolNetbirdSSH type of traffic
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
)
const (
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
protocol = PolicyRuleProtocolUDP
case "icmp":
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
case "netbird-ssh":
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
default:
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
}

View File

@@ -80,6 +80,12 @@ type PolicyRule struct {
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
AuthorizedUser string
}
// Copy returns a copy of a policy rule
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
AuthorizedUser: pm.AuthorizedUser,
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
for k, v := range pm.AuthorizedGroups {
rule.AuthorizedGroups[k] = make([]string, len(v))
copy(rule.AuthorizedGroups[k], v)
}
return rule
}

View File

@@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
_, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
)
if err != nil {
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
}
if userHadPeers {
updateAccountPeers = true
}
updateAccountPeers = true
err = transaction.SaveUser(ctx, updatedUser)
if err != nil {
@@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
}
if settings.GroupsPropagationEnabled && updateAccountPeers {
if updateAccountPeers {
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err)
}

View File

@@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a new regular user should not update account peers and not send peer update
// Creating a new regular user should send peer update (as users are not filtered yet)
t.Run("creating new regular user with no groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
// updating user with no linked peers should not update account peers and not send peer update
// updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
t.Run("updating user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -488,6 +488,8 @@ components:
description: Indicates whether the peer is ephemeral or not
type: boolean
example: false
local_flags:
$ref: '#/components/schemas/PeerLocalFlags'
required:
- city_name
- connected
@@ -514,6 +516,49 @@ components:
- serial_number
- extra_dns_labels
- ephemeral
PeerLocalFlags:
type: object
properties:
rosenpass_enabled:
description: Indicates whether Rosenpass is enabled on this peer
type: boolean
example: true
rosenpass_permissive:
description: Indicates whether Rosenpass is in permissive mode or not
type: boolean
example: false
server_ssh_allowed:
description: Indicates whether SSH access this peer is allowed or not
type: boolean
example: true
disable_client_routes:
description: Indicates whether client routes are disabled on this peer or not
type: boolean
example: false
disable_server_routes:
description: Indicates whether server routes are disabled on this peer or not
type: boolean
example: false
disable_dns:
description: Indicates whether DNS management is disabled on this peer or not
type: boolean
example: false
disable_firewall:
description: Indicates whether firewall management is disabled on this peer or not
type: boolean
example: false
block_lan_access:
description: Indicates whether LAN access is blocked on this peer when used as a routing peer
type: boolean
example: false
block_inbound:
description: Indicates whether inbound traffic is blocked on this peer
type: boolean
example: false
lazy_connection_enabled:
description: Indicates whether lazy connection is enabled on this peer
type: boolean
example: false
PeerTemporaryAccessRequest:
type: object
properties:
@@ -936,7 +981,7 @@ components:
protocol:
description: Policy rule type of the traffic
type: string
enum: ["all", "tcp", "udp", "icmp"]
enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
example: "tcp"
ports:
description: Policy rule affected ports
@@ -949,6 +994,14 @@ components:
type: array
items:
$ref: '#/components/schemas/RulePortRange'
authorized_groups:
description: Map of user group ids to a list of local users
type: object
additionalProperties:
type: array
items:
type: string
example: "group1"
required:
- name
- enabled

View File

@@ -130,10 +130,11 @@ const (
// Defines values for PolicyRuleProtocol.
const (
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
)
// Defines values for PolicyRuleMinimumAction.
@@ -144,10 +145,11 @@ const (
// Defines values for PolicyRuleMinimumProtocol.
const (
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
)
// Defines values for PolicyRuleUpdateAction.
@@ -158,10 +160,11 @@ const (
// Defines values for PolicyRuleUpdateProtocol.
const (
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
)
// Defines values for ResourceType.
@@ -1077,7 +1080,8 @@ type Peer struct {
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
LastSeen time.Time `json:"last_seen"`
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1167,7 +1171,8 @@ type PeerBatch struct {
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
LastSeen time.Time `json:"last_seen"`
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1197,6 +1202,39 @@ type PeerBatch struct {
Version string `json:"version"`
}
// PeerLocalFlags defines model for PeerLocalFlags.
type PeerLocalFlags struct {
// BlockInbound Indicates whether inbound traffic is blocked on this peer
BlockInbound *bool `json:"block_inbound,omitempty"`
// BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
BlockLanAccess *bool `json:"block_lan_access,omitempty"`
// DisableClientRoutes Indicates whether client routes are disabled on this peer or not
DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
// DisableDns Indicates whether DNS management is disabled on this peer or not
DisableDns *bool `json:"disable_dns,omitempty"`
// DisableFirewall Indicates whether firewall management is disabled on this peer or not
DisableFirewall *bool `json:"disable_firewall,omitempty"`
// DisableServerRoutes Indicates whether server routes are disabled on this peer or not
DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
// LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
// RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
// RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
// ServerSshAllowed Indicates whether SSH access this peer is allowed or not
ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
}
// PeerMinimum defines model for PeerMinimum.
type PeerMinimum struct {
// Id Peer ID
@@ -1349,6 +1387,9 @@ type PolicyRule struct {
// Action Policy rule accept or drops packets
Action PolicyRuleAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`
@@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct {
// Action Policy rule accept or drops packets
Action PolicyRuleMinimumAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`
@@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct {
// Action Policy rule accept or drops packets
Action PolicyRuleUpdateAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"`

File diff suppressed because it is too large Load Diff

View File

@@ -332,6 +332,24 @@ message NetworkMap {
bool routesFirewallRulesIsEmpty = 11;
repeated ForwardingRule forwardingRules = 12;
// SSHAuth represents SSH authorization configuration
SSHAuth sshAuth = 13;
}
message SSHAuth {
// UserIDClaim is the JWT claim to be used to get the users ID
string UserIDClaim = 1;
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
repeated bytes AuthorizedUsers = 2;
// MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
map<string, MachineUserIndexes> machine_users = 3;
}
message MachineUserIndexes {
repeated uint32 indexes = 1;
}
// RemotePeerConfig represents a configuration of a remote peer.

View File

@@ -0,0 +1,28 @@
package sshauth
import (
"encoding/hex"
"golang.org/x/crypto/blake2b"
)
// UserIDHash represents a hashed user ID (BLAKE2b-128)
type UserIDHash [16]byte
// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
// This function must produce the same hash on both client and management server
func HashUserID(userID string) (UserIDHash, error) {
hash, err := blake2b.New(16, nil)
if err != nil {
return UserIDHash{}, err
}
hash.Write([]byte(userID))
var result UserIDHash
copy(result[:], hash.Sum(nil))
return result, nil
}
// String returns the hexadecimal string representation of the hash
func (h UserIDHash) String() string {
return hex.EncodeToString(h[:])
}

View File

@@ -0,0 +1,210 @@
package sshauth
import (
"testing"
)
func TestHashUserID(t *testing.T) {
tests := []struct {
name string
userID string
}{
{
name: "simple user ID",
userID: "user@example.com",
},
{
name: "UUID format",
userID: "550e8400-e29b-41d4-a716-446655440000",
},
{
name: "numeric ID",
userID: "12345",
},
{
name: "empty string",
userID: "",
},
{
name: "special characters",
userID: "user+test@domain.com",
},
{
name: "unicode characters",
userID: "用户@example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hash, err := HashUserID(tt.userID)
if err != nil {
t.Errorf("HashUserID() error = %v, want nil", err)
return
}
// Verify hash is non-zero for non-empty inputs
if tt.userID != "" && hash == [16]byte{} {
t.Errorf("HashUserID() returned zero hash for non-empty input")
}
})
}
}
func TestHashUserID_Consistency(t *testing.T) {
userID := "test@example.com"
hash1, err1 := HashUserID(userID)
if err1 != nil {
t.Fatalf("First HashUserID() error = %v", err1)
}
hash2, err2 := HashUserID(userID)
if err2 != nil {
t.Fatalf("Second HashUserID() error = %v", err2)
}
if hash1 != hash2 {
t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
}
}
func TestHashUserID_Uniqueness(t *testing.T) {
tests := []struct {
userID1 string
userID2 string
}{
{"user1@example.com", "user2@example.com"},
{"alice@domain.com", "bob@domain.com"},
{"test", "test1"},
{"", "a"},
}
for _, tt := range tests {
hash1, err1 := HashUserID(tt.userID1)
if err1 != nil {
t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
}
hash2, err2 := HashUserID(tt.userID2)
if err2 != nil {
t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
}
if hash1 == hash2 {
t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
}
}
}
func TestUserIDHash_String(t *testing.T) {
tests := []struct {
name string
hash UserIDHash
expected string
}{
{
name: "zero hash",
hash: [16]byte{},
expected: "00000000000000000000000000000000",
},
{
name: "small value",
hash: [16]byte{15: 0xff},
expected: "000000000000000000000000000000ff",
},
{
name: "large value",
hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
expected: "0000000000000000deadbeefcafebabe",
},
{
name: "max value",
hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
expected: "ffffffffffffffffffffffffffffffff",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.hash.String()
if result != tt.expected {
t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
}
})
}
}
func TestUserIDHash_String_Length(t *testing.T) {
// Test that String() always returns 32 hex characters (16 bytes * 2)
userID := "test@example.com"
hash, err := HashUserID(userID)
if err != nil {
t.Fatalf("HashUserID() error = %v", err)
}
result := hash.String()
if len(result) != 32 {
t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
}
// Verify it's valid hex
for i, c := range result {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
}
}
}
func TestHashUserID_KnownValues(t *testing.T) {
// Test with known BLAKE2b-128 values to ensure correct implementation
tests := []struct {
name string
userID string
expected UserIDHash
}{
{
name: "empty string",
userID: "",
// BLAKE2b-128 of empty string
expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
},
{
name: "single character 'a'",
userID: "a",
// BLAKE2b-128 of "a"
expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hash, err := HashUserID(tt.userID)
if err != nil {
t.Errorf("HashUserID() error = %v", err)
return
}
if hash != tt.expected {
t.Errorf("HashUserID(%q) = %x, want %x",
tt.userID, hash, tt.expected)
}
})
}
}
func BenchmarkHashUserID(b *testing.B) {
userID := "user@example.com"
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = HashUserID(userID)
}
}
func BenchmarkUserIDHash_String(b *testing.B) {
hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = hash.String()
}
}