mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 00:54:01 -04:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
@@ -20,22 +20,23 @@ const (
|
||||
)
|
||||
|
||||
type AccessLogEntry struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ServiceID string `gorm:"index"`
|
||||
Timestamp time.Time `gorm:"index"`
|
||||
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
Method string `gorm:"index"`
|
||||
Host string `gorm:"index"`
|
||||
Path string `gorm:"index"`
|
||||
Duration time.Duration `gorm:"index"`
|
||||
StatusCode int `gorm:"index"`
|
||||
Reason string
|
||||
UserId string `gorm:"index"`
|
||||
AuthMethodUsed string `gorm:"index"`
|
||||
BytesUpload int64 `gorm:"index"`
|
||||
BytesDownload int64 `gorm:"index"`
|
||||
Protocol AccessLogProtocol `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ServiceID string `gorm:"index"`
|
||||
Timestamp time.Time `gorm:"index"`
|
||||
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
SubdivisionCode string
|
||||
Method string `gorm:"index"`
|
||||
Host string `gorm:"index"`
|
||||
Path string `gorm:"index"`
|
||||
Duration time.Duration `gorm:"index"`
|
||||
StatusCode int `gorm:"index"`
|
||||
Reason string
|
||||
UserId string `gorm:"index"`
|
||||
AuthMethodUsed string `gorm:"index"`
|
||||
BytesUpload int64 `gorm:"index"`
|
||||
BytesDownload int64 `gorm:"index"`
|
||||
Protocol AccessLogProtocol `gorm:"index"`
|
||||
}
|
||||
|
||||
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||
@@ -105,6 +106,11 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||
cityName = &a.GeoLocation.CityName
|
||||
}
|
||||
|
||||
var subdivisionCode *string
|
||||
if a.SubdivisionCode != "" {
|
||||
subdivisionCode = &a.SubdivisionCode
|
||||
}
|
||||
|
||||
var protocol *string
|
||||
if a.Protocol != "" {
|
||||
p := string(a.Protocol)
|
||||
@@ -112,22 +118,23 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||
}
|
||||
|
||||
return &api.ProxyAccessLog{
|
||||
Id: a.ID,
|
||||
ServiceId: a.ServiceID,
|
||||
Timestamp: a.Timestamp,
|
||||
Method: a.Method,
|
||||
Host: a.Host,
|
||||
Path: a.Path,
|
||||
DurationMs: int(a.Duration.Milliseconds()),
|
||||
StatusCode: a.StatusCode,
|
||||
SourceIp: sourceIP,
|
||||
Reason: reason,
|
||||
UserId: userID,
|
||||
AuthMethodUsed: authMethod,
|
||||
CountryCode: countryCode,
|
||||
CityName: cityName,
|
||||
BytesUpload: a.BytesUpload,
|
||||
BytesDownload: a.BytesDownload,
|
||||
Protocol: protocol,
|
||||
Id: a.ID,
|
||||
ServiceId: a.ServiceID,
|
||||
Timestamp: a.Timestamp,
|
||||
Method: a.Method,
|
||||
Host: a.Host,
|
||||
Path: a.Path,
|
||||
DurationMs: int(a.Duration.Milliseconds()),
|
||||
StatusCode: a.StatusCode,
|
||||
SourceIp: sourceIP,
|
||||
Reason: reason,
|
||||
UserId: userID,
|
||||
AuthMethodUsed: authMethod,
|
||||
CountryCode: countryCode,
|
||||
CityName: cityName,
|
||||
SubdivisionCode: subdivisionCode,
|
||||
BytesUpload: a.BytesUpload,
|
||||
BytesDownload: a.BytesDownload,
|
||||
Protocol: protocol,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,9 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
|
||||
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
|
||||
logEntry.GeoLocation.CityName = location.City.Names.En
|
||||
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
|
||||
if len(location.Subdivisions) > 0 {
|
||||
logEntry.SubdivisionCode = location.Subdivisions[0].ISOCode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -229,6 +230,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
return fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
for i, h := range service.Auth.HeaderAuths {
|
||||
if h != nil && h.Enabled && h.Value == "" {
|
||||
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
|
||||
}
|
||||
}
|
||||
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate session keys: %w", err)
|
||||
@@ -488,6 +495,9 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
|
||||
return err
|
||||
}
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
m.preserveListenPort(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
@@ -544,18 +554,52 @@ func isHTTPFamily(mode string) bool {
|
||||
return mode == "" || mode == "http"
|
||||
}
|
||||
|
||||
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) {
|
||||
if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
svc.Auth.PasswordAuth.Password == "" {
|
||||
svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
svc.Auth.PinAuth.Pin == "" {
|
||||
svc.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
|
||||
preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
|
||||
}
|
||||
|
||||
// preserveHeaderAuthHashes fills in empty header auth values from the existing
|
||||
// service so that unchanged secrets are not lost on update.
|
||||
func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) {
|
||||
if len(headers) == 0 || len(existing) == 0 {
|
||||
return
|
||||
}
|
||||
existingByHeader := make(map[string]string, len(existing))
|
||||
for _, h := range existing {
|
||||
if h != nil && h.Value != "" {
|
||||
existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value
|
||||
}
|
||||
}
|
||||
for _, h := range headers {
|
||||
if h != nil && h.Enabled && h.Value == "" {
|
||||
if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok {
|
||||
h.Value = hash
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateHeaderAuthValues checks that all enabled header auths have a value
|
||||
// (either freshly provided or preserved from the existing service).
|
||||
func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error {
|
||||
for i, h := range headers {
|
||||
if h != nil && h.Enabled && h.Value == "" {
|
||||
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
|
||||
@@ -605,6 +649,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
}
|
||||
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
||||
}
|
||||
default:
|
||||
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -7,14 +7,15 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
@@ -91,10 +92,37 @@ type BearerAuthConfig struct {
|
||||
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// HeaderAuthConfig defines a static header-value auth check.
|
||||
// The proxy compares the incoming header value against the stored hash.
|
||||
type HeaderAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Header string `json:"header"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
||||
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
||||
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
||||
HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AccessRestrictions controls who can connect to the service based on IP or geography.
|
||||
type AccessRestrictions struct {
|
||||
AllowedCIDRs []string `json:"allowed_cidrs,omitempty" gorm:"serializer:json"`
|
||||
BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"`
|
||||
AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"`
|
||||
BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the AccessRestrictions.
|
||||
func (r AccessRestrictions) Copy() AccessRestrictions {
|
||||
return AccessRestrictions{
|
||||
AllowedCIDRs: slices.Clone(r.AllowedCIDRs),
|
||||
BlockedCIDRs: slices.Clone(r.BlockedCIDRs),
|
||||
AllowedCountries: slices.Clone(r.AllowedCountries),
|
||||
BlockedCountries: slices.Clone(r.BlockedCountries),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthConfig) HashSecrets() error {
|
||||
@@ -114,6 +142,16 @@ func (a *AuthConfig) HashSecrets() error {
|
||||
a.PinAuth.Pin = hashedPin
|
||||
}
|
||||
|
||||
for i, h := range a.HeaderAuths {
|
||||
if h != nil && h.Enabled && h.Value != "" {
|
||||
hashedValue, err := argon2id.Hash(h.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash header auth[%d] value: %w", i, err)
|
||||
}
|
||||
h.Value = hashedValue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -124,6 +162,11 @@ func (a *AuthConfig) ClearSecrets() {
|
||||
if a.PinAuth != nil {
|
||||
a.PinAuth.Pin = ""
|
||||
}
|
||||
for _, h := range a.HeaderAuths {
|
||||
if h != nil {
|
||||
h.Value = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Meta struct {
|
||||
@@ -143,12 +186,13 @@ type Service struct {
|
||||
Enabled bool
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
|
||||
SourcePeer string `gorm:"index:idx_service_source_peer"`
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Restrictions AccessRestrictions `gorm:"serializer:json"`
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
|
||||
SourcePeer string `gorm:"index:idx_service_source_peer"`
|
||||
// Mode determines the service type: "http", "tcp", "udp", or "tls".
|
||||
Mode string `gorm:"default:'http'"`
|
||||
ListenPort uint16
|
||||
@@ -188,6 +232,20 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.Auth.HeaderAuths) > 0 {
|
||||
apiHeaders := make([]api.HeaderAuthConfig, 0, len(s.Auth.HeaderAuths))
|
||||
for _, h := range s.Auth.HeaderAuths {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
apiHeaders = append(apiHeaders, api.HeaderAuthConfig{
|
||||
Enabled: h.Enabled,
|
||||
Header: h.Header,
|
||||
})
|
||||
}
|
||||
authConfig.HeaderAuths = &apiHeaders
|
||||
}
|
||||
|
||||
// Convert internal targets to API targets
|
||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
@@ -222,18 +280,19 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
listenPort := int(s.ListenPort)
|
||||
|
||||
resp := &api.Service{
|
||||
Id: s.ID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
Targets: apiTargets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: &s.PassHostHeader,
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
Meta: meta,
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
Id: s.ID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
Targets: apiTargets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: &s.PassHostHeader,
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
AccessRestrictions: restrictionsToAPI(s.Restrictions),
|
||||
Meta: meta,
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
@@ -263,7 +322,16 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
auth.Oidc = true
|
||||
}
|
||||
|
||||
return &proto.ProxyMapping{
|
||||
for _, h := range s.Auth.HeaderAuths {
|
||||
if h != nil && h.Enabled {
|
||||
auth.HeaderAuths = append(auth.HeaderAuths, &proto.HeaderAuth{
|
||||
Header: h.Header,
|
||||
HashedValue: h.Value,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
mapping := &proto.ProxyMapping{
|
||||
Type: operationToProtoType(operation),
|
||||
Id: s.ID,
|
||||
Domain: s.Domain,
|
||||
@@ -276,6 +344,12 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
Mode: s.Mode,
|
||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||
}
|
||||
|
||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||
mapping.AccessRestrictions = r
|
||||
}
|
||||
|
||||
return mapping
|
||||
}
|
||||
|
||||
// buildPathMappings constructs PathMapping entries from targets.
|
||||
@@ -334,8 +408,7 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
case Delete:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||
default:
|
||||
log.Fatalf("unknown operation type: %v", op)
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
panic(fmt.Sprintf("unknown operation type: %v", op))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,6 +550,10 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
s.Auth = authFromAPI(req.Auth)
|
||||
}
|
||||
|
||||
if req.AccessRestrictions != nil {
|
||||
s.Restrictions = restrictionsFromAPI(req.AccessRestrictions)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -538,9 +615,70 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
|
||||
}
|
||||
auth.BearerAuth = bearerAuth
|
||||
}
|
||||
if reqAuth.HeaderAuths != nil {
|
||||
for _, h := range *reqAuth.HeaderAuths {
|
||||
auth.HeaderAuths = append(auth.HeaderAuths, &HeaderAuthConfig{
|
||||
Enabled: h.Enabled,
|
||||
Header: h.Header,
|
||||
Value: h.Value,
|
||||
})
|
||||
}
|
||||
}
|
||||
return auth
|
||||
}
|
||||
|
||||
func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions {
|
||||
if r == nil {
|
||||
return AccessRestrictions{}
|
||||
}
|
||||
var res AccessRestrictions
|
||||
if r.AllowedCidrs != nil {
|
||||
res.AllowedCIDRs = *r.AllowedCidrs
|
||||
}
|
||||
if r.BlockedCidrs != nil {
|
||||
res.BlockedCIDRs = *r.BlockedCidrs
|
||||
}
|
||||
if r.AllowedCountries != nil {
|
||||
res.AllowedCountries = *r.AllowedCountries
|
||||
}
|
||||
if r.BlockedCountries != nil {
|
||||
res.BlockedCountries = *r.BlockedCountries
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
|
||||
return nil
|
||||
}
|
||||
res := &api.AccessRestrictions{}
|
||||
if len(r.AllowedCIDRs) > 0 {
|
||||
res.AllowedCidrs = &r.AllowedCIDRs
|
||||
}
|
||||
if len(r.BlockedCIDRs) > 0 {
|
||||
res.BlockedCidrs = &r.BlockedCIDRs
|
||||
}
|
||||
if len(r.AllowedCountries) > 0 {
|
||||
res.AllowedCountries = &r.AllowedCountries
|
||||
}
|
||||
if len(r.BlockedCountries) > 0 {
|
||||
res.BlockedCountries = &r.BlockedCountries
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &proto.AccessRestrictions{
|
||||
AllowedCidrs: r.AllowedCIDRs,
|
||||
BlockedCidrs: r.BlockedCIDRs,
|
||||
AllowedCountries: r.AllowedCountries,
|
||||
BlockedCountries: r.BlockedCountries,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Validate() error {
|
||||
if s.Name == "" {
|
||||
return errors.New("service name is required")
|
||||
@@ -557,6 +695,13 @@ func (s *Service) Validate() error {
|
||||
s.Mode = ModeHTTP
|
||||
}
|
||||
|
||||
if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch s.Mode {
|
||||
case ModeHTTP:
|
||||
return s.validateHTTPMode()
|
||||
@@ -657,6 +802,21 @@ func (s *Service) validateL4Target(target *Target) error {
|
||||
if target.Path != nil && *target.Path != "" && *target.Path != "/" {
|
||||
return errors.New("path is not supported for L4 services")
|
||||
}
|
||||
if target.Options.SessionIdleTimeout < 0 {
|
||||
return errors.New("session_idle_timeout must be positive for L4 services")
|
||||
}
|
||||
if target.Options.RequestTimeout < 0 {
|
||||
return errors.New("request_timeout must be positive for L4 services")
|
||||
}
|
||||
if target.Options.SkipTLSVerify {
|
||||
return errors.New("skip_tls_verify is not supported for L4 services")
|
||||
}
|
||||
if target.Options.PathRewrite != "" {
|
||||
return errors.New("path_rewrite is not supported for L4 services")
|
||||
}
|
||||
if len(target.Options.CustomHeaders) > 0 {
|
||||
return errors.New("custom_headers is not supported for L4 services")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -688,11 +848,9 @@ func IsPortBasedProtocol(mode string) bool {
|
||||
}
|
||||
|
||||
const (
|
||||
maxRequestTimeout = 5 * time.Minute
|
||||
maxSessionIdleTimeout = 10 * time.Minute
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
)
|
||||
|
||||
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
|
||||
@@ -731,22 +889,12 @@ func validateTargetOptions(idx int, opts *TargetOptions) error {
|
||||
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
|
||||
}
|
||||
|
||||
if opts.RequestTimeout != 0 {
|
||||
if opts.RequestTimeout <= 0 {
|
||||
return fmt.Errorf("target %d: request_timeout must be positive", idx)
|
||||
}
|
||||
if opts.RequestTimeout > maxRequestTimeout {
|
||||
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
|
||||
}
|
||||
if opts.RequestTimeout < 0 {
|
||||
return fmt.Errorf("target %d: request_timeout must be positive", idx)
|
||||
}
|
||||
|
||||
if opts.SessionIdleTimeout != 0 {
|
||||
if opts.SessionIdleTimeout <= 0 {
|
||||
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
|
||||
}
|
||||
if opts.SessionIdleTimeout > maxSessionIdleTimeout {
|
||||
return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout)
|
||||
}
|
||||
if opts.SessionIdleTimeout < 0 {
|
||||
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
|
||||
}
|
||||
|
||||
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
|
||||
@@ -796,6 +944,93 @@ func containsCRLF(s string) bool {
|
||||
return strings.ContainsAny(s, "\r\n")
|
||||
}
|
||||
|
||||
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
|
||||
seen := make(map[string]struct{})
|
||||
for i, h := range headers {
|
||||
if h == nil || !h.Enabled {
|
||||
continue
|
||||
}
|
||||
if h.Header == "" {
|
||||
return fmt.Errorf("header_auths[%d]: header name is required", i)
|
||||
}
|
||||
if !httpHeaderNameRe.MatchString(h.Header) {
|
||||
return fmt.Errorf("header_auths[%d]: header name %q is not a valid HTTP header name", i, h.Header)
|
||||
}
|
||||
canonical := http.CanonicalHeaderKey(h.Header)
|
||||
if _, ok := hopByHopHeaders[canonical]; ok {
|
||||
return fmt.Errorf("header_auths[%d]: header %q is a hop-by-hop header and cannot be used for auth", i, h.Header)
|
||||
}
|
||||
if _, ok := reservedHeaders[canonical]; ok {
|
||||
return fmt.Errorf("header_auths[%d]: header %q is managed by the proxy and cannot be used for auth", i, h.Header)
|
||||
}
|
||||
if canonical == "Host" {
|
||||
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
|
||||
}
|
||||
if _, dup := seen[canonical]; dup {
|
||||
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
|
||||
}
|
||||
seen[canonical] = struct{}{}
|
||||
if len(h.Value) > maxHeaderValueLen {
|
||||
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
maxCIDREntries = 200
|
||||
maxCountryEntries = 50
|
||||
)
|
||||
|
||||
// validateAccessRestrictions validates and normalizes access restriction
|
||||
// entries. Country codes are uppercased in place.
|
||||
func validateAccessRestrictions(r *AccessRestrictions) error {
|
||||
if len(r.AllowedCIDRs) > maxCIDREntries {
|
||||
return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries)
|
||||
}
|
||||
if len(r.BlockedCIDRs) > maxCIDREntries {
|
||||
return fmt.Errorf("blocked_cidrs: exceeds maximum of %d entries", maxCIDREntries)
|
||||
}
|
||||
if len(r.AllowedCountries) > maxCountryEntries {
|
||||
return fmt.Errorf("allowed_countries: exceeds maximum of %d entries", maxCountryEntries)
|
||||
}
|
||||
if len(r.BlockedCountries) > maxCountryEntries {
|
||||
return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries)
|
||||
}
|
||||
|
||||
for i, raw := range r.AllowedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allowed_cidrs[%d]: %w", i, err)
|
||||
}
|
||||
if prefix != prefix.Masked() {
|
||||
return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
|
||||
}
|
||||
}
|
||||
for i, raw := range r.BlockedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("blocked_cidrs[%d]: %w", i, err)
|
||||
}
|
||||
if prefix != prefix.Masked() {
|
||||
return fmt.Errorf("blocked_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
|
||||
}
|
||||
}
|
||||
for i, code := range r.AllowedCountries {
|
||||
if len(code) != 2 {
|
||||
return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
|
||||
}
|
||||
r.AllowedCountries[i] = strings.ToUpper(code)
|
||||
}
|
||||
for i, code := range r.BlockedCountries {
|
||||
if len(code) != 2 {
|
||||
return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
|
||||
}
|
||||
r.BlockedCountries[i] = strings.ToUpper(code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) EventMeta() map[string]any {
|
||||
meta := map[string]any{
|
||||
"name": s.Name,
|
||||
@@ -827,9 +1062,17 @@ func (s *Service) EventMeta() map[string]any {
|
||||
}
|
||||
|
||||
func (s *Service) isAuthEnabled() bool {
|
||||
return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
|
||||
if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
|
||||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
|
||||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled)
|
||||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) {
|
||||
return true
|
||||
}
|
||||
for _, h := range s.Auth.HeaderAuths {
|
||||
if h != nil && h.Enabled {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) Copy() *Service {
|
||||
@@ -866,6 +1109,16 @@ func (s *Service) Copy() *Service {
|
||||
}
|
||||
authCopy.BearerAuth = &ba
|
||||
}
|
||||
if len(s.Auth.HeaderAuths) > 0 {
|
||||
authCopy.HeaderAuths = make([]*HeaderAuthConfig, len(s.Auth.HeaderAuths))
|
||||
for i, h := range s.Auth.HeaderAuths {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
hCopy := *h
|
||||
authCopy.HeaderAuths[i] = &hCopy
|
||||
}
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
@@ -878,6 +1131,7 @@ func (s *Service) Copy() *Service {
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Auth: authCopy,
|
||||
Restrictions: s.Restrictions.Copy(),
|
||||
Meta: s.Meta,
|
||||
SessionPrivateKey: s.SessionPrivateKey,
|
||||
SessionPublicKey: s.SessionPublicKey,
|
||||
|
||||
@@ -120,9 +120,9 @@ func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
|
||||
}{
|
||||
{"valid 30s", 30 * time.Second, ""},
|
||||
{"valid 2m", 2 * time.Minute, ""},
|
||||
{"valid 10m", 10 * time.Minute, ""},
|
||||
{"zero is fine", 0, ""},
|
||||
{"negative", -1 * time.Second, "must be positive"},
|
||||
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -493,16 +494,17 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
|
||||
// should be set on the copy.
|
||||
func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
return &proto.ProxyMapping{
|
||||
Type: m.Type,
|
||||
Id: m.Id,
|
||||
AccountId: m.AccountId,
|
||||
Domain: m.Domain,
|
||||
Path: m.Path,
|
||||
Auth: m.Auth,
|
||||
PassHostHeader: m.PassHostHeader,
|
||||
RewriteRedirects: m.RewriteRedirects,
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
Type: m.Type,
|
||||
Id: m.Id,
|
||||
AccountId: m.AccountId,
|
||||
Domain: m.Domain,
|
||||
Path: m.Path,
|
||||
Auth: m.Auth,
|
||||
PassHostHeader: m.PassHostHeader,
|
||||
RewriteRedirects: m.RewriteRedirects,
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
AccessRestrictions: m.AccessRestrictions,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -561,6 +563,8 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
|
||||
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
||||
case *proto.AuthenticateRequest_Password:
|
||||
return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth)
|
||||
case *proto.AuthenticateRequest_HeaderAuth:
|
||||
return s.authenticateHeader(ctx, req.GetId(), v, service.Auth.HeaderAuths)
|
||||
default:
|
||||
return false, "", ""
|
||||
}
|
||||
@@ -594,6 +598,35 @@ func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID
|
||||
return true, "password-user", proxyauth.MethodPassword
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticateHeader(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_HeaderAuth, auths []*rpservice.HeaderAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if len(auths) == 0 {
|
||||
log.WithContext(ctx).Debugf("header authentication attempted but no header auths configured for service %s", serviceID)
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
headerName := http.CanonicalHeaderKey(req.HeaderAuth.GetHeaderName())
|
||||
|
||||
var lastErr error
|
||||
for _, auth := range auths {
|
||||
if auth == nil || !auth.Enabled {
|
||||
continue
|
||||
}
|
||||
if headerName != "" && http.CanonicalHeaderKey(auth.Header) != headerName {
|
||||
continue
|
||||
}
|
||||
if err := argon2id.Verify(req.HeaderAuth.GetHeaderValue(), auth.Value); err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return true, "header-user", proxyauth.MethodHeader
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
s.logAuthenticationError(ctx, lastErr, "Header")
|
||||
}
|
||||
return false, "", ""
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) {
|
||||
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType)
|
||||
@@ -752,6 +785,9 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
}
|
||||
if redirectURL.Scheme != "https" && redirectURL.Scheme != "http" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme")
|
||||
}
|
||||
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
if err != nil {
|
||||
@@ -836,12 +872,9 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
|
||||
|
||||
// ValidateState validates the state parameter from an OAuth callback.
|
||||
// Returns the original redirect URL if valid, or an error if invalid.
|
||||
// The HMAC is verified before consuming the PKCE verifier to prevent
|
||||
// an attacker from invalidating a legitimate user's auth flow.
|
||||
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
|
||||
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
|
||||
if !ok {
|
||||
return "", "", errors.New("no verifier for state")
|
||||
}
|
||||
|
||||
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
||||
parts := strings.Split(state, "|")
|
||||
if len(parts) != 3 {
|
||||
@@ -865,6 +898,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
return "", "", errors.New("invalid state signature")
|
||||
}
|
||||
|
||||
// Consume the PKCE verifier only after HMAC validation passes.
|
||||
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
|
||||
if !ok {
|
||||
return "", "", errors.New("no verifier for state")
|
||||
}
|
||||
|
||||
return verifier, redirectURL, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user