[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

View File

@@ -0,0 +1,146 @@
package jwt
import (
"errors"
"net/url"
"time"
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/auth"
)
const (
// AccountIDSuffix suffix for the account id claim
AccountIDSuffix = "wt_account_id"
// DomainIDSuffix suffix for the domain id claim
DomainIDSuffix = "wt_account_domain"
// DomainCategorySuffix suffix for the domain category claim
DomainCategorySuffix = "wt_account_domain_category"
// UserIDClaim claim for the user id
UserIDClaim = "sub"
// LastLoginSuffix claim for the last login
LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
)
var (
errUserIDClaimEmpty = errors.New("user ID claim token value is empty")
)
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
authAudience string
userIDClaim string
}
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
type ClaimsExtractorOption func(*ClaimsExtractor)
// WithAudience sets the audience for the extractor
func WithAudience(audience string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.authAudience = audience
}
}
// WithUserIDClaim sets the user id claim for the extractor
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.userIDClaim = userIDClaim
}
}
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
ce := &ClaimsExtractor{}
for _, option := range options {
option(ce)
}
if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim
}
return ce
}
func parseTime(timeString string) time.Time {
if timeString == "" {
return time.Time{}
}
parsedTime, err := time.Parse(time.RFC3339, timeString)
if err != nil {
return time.Time{}
}
return parsedTime
}
func (c ClaimsExtractor) audienceClaim(claimName string) string {
url, err := url.JoinPath(c.authAudience, claimName)
if err != nil {
return c.authAudience + claimName // as it was previously
}
return url
}
// ToUserAuth extracts user authentication information from a JWT token
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) {
claims := token.Claims.(jwt.MapClaims)
userAuth := auth.UserAuth{}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return userAuth, errUserIDClaimEmpty
}
userAuth.UserId = userID
if accountIDClaim, ok := claims[c.audienceClaim(AccountIDSuffix)]; ok {
userAuth.AccountId = accountIDClaim.(string)
}
if domainClaim, ok := claims[c.audienceClaim(DomainIDSuffix)]; ok {
userAuth.Domain = domainClaim.(string)
}
if domainCategoryClaim, ok := claims[c.audienceClaim(DomainCategorySuffix)]; ok {
userAuth.DomainCategory = domainCategoryClaim.(string)
}
if lastLoginClaimString, ok := claims[c.audienceClaim(LastLoginSuffix)]; ok {
userAuth.LastLogin = parseTime(lastLoginClaimString.(string))
}
if invitedBool, ok := claims[c.audienceClaim(Invited)]; ok {
if value, ok := invitedBool.(bool); ok {
userAuth.Invited = value
}
}
return userAuth, nil
}
// ToGroups extracts group information from a JWT token
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
claims := token.Claims.(jwt.MapClaims)
userJWTGroups := make([]string, 0)
if claim, ok := claims[claimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
} else {
log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
}
}
}
} else {
log.Debugf("JWT claim %q is not a string array", claimName)
}
return userJWTGroups
}

View File

@@ -0,0 +1,288 @@
package jwt
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
)
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
}
// The supported elliptic curves types
const (
// p256 represents a cryptographic elliptical curve type.
p256 = "P-256"
// p384 represents a cryptographic elliptical curve type.
p384 = "P-384"
// p521 represents a cryptographic elliptical curve type.
p521 = "P-521"
)
// JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
X5c []string `json:"x5c"`
}
type Validator struct {
lock sync.Mutex
issuer string
audienceList []string
keysLocation string
idpSignkeyRefreshEnabled bool
keys *Jwks
}
var (
errKeyNotFound = errors.New("unable to find appropriate key")
errTokenEmpty = errors.New("required authorization token not found")
errTokenInvalid = errors.New("token is invalid")
errTokenParsing = errors.New("token could not be parsed")
)
func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
keys, err := getPemKeys(keysLocation)
if err != nil {
log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err)
}
return &Validator{
keys: keys,
issuer: issuer,
audienceList: audienceList,
keysLocation: keysLocation,
idpSignkeyRefreshEnabled: idpSignkeyRefreshEnabled,
}
}
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
// If keys are rotated, verify the keys prior to token validation
if v.idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones
// @todo propose a separate go routine to regularly check these to prevent blocking when actually
// validating the token
if !v.keys.stillValid() {
v.lock.Lock()
defer v.lock.Unlock()
refreshedKeys, err := getPemKeys(v.keysLocation)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = v.keys
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
v.keys = refreshedKeys
}
}
publicKey, err := getPublicKey(token, v.keys)
if err == nil {
return publicKey, nil
}
msg := fmt.Sprintf("getPublicKey error: %s", err)
if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled {
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
}
log.WithContext(ctx).Error(msg)
return nil, err
}
}
// ValidateAndParse validates the token and returns the parsed token
func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// If we get here, the required token is missing
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, errTokenEmpty
}
// Now parse the token
parsedToken, err := jwt.Parse(
token,
v.getKeyFunc(ctx),
jwt.WithAudience(v.audienceList...),
jwt.WithIssuer(v.issuer),
jwt.WithIssuedAt(),
)
// Check if there was an error in parsing...
if err != nil {
err = fmt.Errorf("%w: %s", errTokenParsing, err)
log.WithContext(ctx).Error(err.Error())
return nil, err
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
log.WithContext(ctx).Debug(errTokenInvalid.Error())
return nil, errTokenInvalid
}
return parsedToken, nil
}
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
jwks := &Jwks{}
url, err := url.ParseRequestURI(keysLocation)
if err != nil {
return jwks, err
}
resp, err := http.Get(url.String())
if err != nil {
return jwks, err
}
defer resp.Body.Close()
err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil {
return jwks, err
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, nil
}
func getPublicKey(token *jwt.Token, jwks *Jwks) (interface{}, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
for k := range jwks.Keys {
if token.Header["kid"] != jwks.Keys[k].Kid {
continue
}
if len(jwks.Keys[k].X5c) != 0 {
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
}
if jwks.Keys[k].Kty == "RSA" {
return getPublicKeyFromRSA(jwks.Keys[k])
}
if jwks.Keys[k].Kty == "EC" {
return getPublicKeyFromECDSA(jwks.Keys[k])
}
}
return nil, errKeyNotFound
}
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
return nil, fmt.Errorf("ecdsa key incomplete")
}
var xCoordinate []byte
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
return nil, err
}
var yCoordinate []byte
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
return nil, err
}
publicKey = &ecdsa.PublicKey{}
var curve elliptic.Curve
switch jwk.Crv {
case p256:
curve = elliptic.P256()
case p384:
curve = elliptic.P384()
case p521:
curve = elliptic.P521()
}
publicKey.Curve = curve
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
return publicKey, nil
}
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, err
}
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, err
}
var n, e big.Int
e.SetBytes(decodedE)
n.SetBytes(decodedN)
return &rsa.PublicKey{
E: int(e.Int64()),
N: &n,
}, nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
for _, directive := range directives {
directive = strings.TrimSpace(directive)
if strings.HasPrefix(directive, "max-age=") {
// Extract the max-age value
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
return 0
}
return maxAge
}
}
return 0
}

28
shared/auth/user.go Normal file
View File

@@ -0,0 +1,28 @@
package auth
import (
"time"
)
type UserAuth struct {
// The account id the user is accessing
AccountId string
// The account domain
Domain string
// The account domain category, TBC values
DomainCategory string
// Indicates whether this user was invited, TBC logic
Invited bool
// Indicates whether this is a child account
IsChild bool
// The user id
UserId string
// Last login time for this user
LastLogin time.Time
// The Groups the user belongs to on this account
Groups []string
// Indicates whether this user has authenticated with a Personal Access Token
IsPAT bool
}

View File

@@ -5,4 +5,4 @@ const (
AccountIDKey = "accountID"
UserIDKey = "userID"
PeerIDKey = "peerID"
)
)

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -117,7 +118,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package operations
// Operation represents a permission operation type
type Operation string
type Operation string

File diff suppressed because it is too large Load Diff

View File

@@ -146,6 +146,12 @@ message Flags {
bool blockInbound = 9;
bool lazyConnectionEnabled = 10;
bool enableSSHRoot = 11;
bool enableSSHSFTP = 12;
bool enableSSHLocalPortForwarding = 13;
bool enableSSHRemotePortForwarding = 14;
bool disableSSHAuth = 15;
}
// PeerSystemMeta is machine meta data like OS and version.
@@ -202,6 +208,8 @@ message NetbirdConfig {
RelayConfig relay = 4;
FlowConfig flow = 5;
JWTConfig jwt = 6;
}
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
@@ -240,6 +248,14 @@ message FlowConfig {
bool dnsCollection = 8;
}
// JWTConfig represents JWT authentication configuration
message JWTConfig {
string issuer = 1;
string audience = 2;
string keysLocation = 3;
int64 maxTokenAge = 4;
}
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {
@@ -335,6 +351,8 @@ message SSHConfig {
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
// This property should be ignore if SSHConfig comes from PeerConfig.
bytes sshPubKey = 2;
JWTConfig jwtConfig = 3;
}
// DeviceAuthorizationFlowRequest empty struct for future expansion

View File

@@ -11,8 +11,8 @@ import (
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/shared/relay/tls"
nbnet "github.com/netbirdio/netbird/client/net"
quictls "github.com/netbirdio/netbird/shared/relay/tls"
)
type Dialer struct {

View File

@@ -14,9 +14,9 @@ import (
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/client/net"
)
type Dialer struct {

View File

@@ -3,4 +3,4 @@ package relay
const (
// WebSocketURLPath is the path for the websocket relay connection
WebSocketURLPath = "/relay"
)
)