[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -20,22 +20,23 @@ const (
)
type AccessLogEntry struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
Protocol AccessLogProtocol `gorm:"index"`
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
SubdivisionCode string
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
Protocol AccessLogProtocol `gorm:"index"`
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -105,6 +106,11 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
cityName = &a.GeoLocation.CityName
}
var subdivisionCode *string
if a.SubdivisionCode != "" {
subdivisionCode = &a.SubdivisionCode
}
var protocol *string
if a.Protocol != "" {
p := string(a.Protocol)
@@ -112,22 +118,23 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
}
return &api.ProxyAccessLog{
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,
Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode,
SourceIp: sourceIP,
Reason: reason,
UserId: userID,
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
Protocol: protocol,
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,
Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode,
SourceIp: sourceIP,
Reason: reason,
UserId: userID,
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
SubdivisionCode: subdivisionCode,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
Protocol: protocol,
}
}

View File

@@ -41,6 +41,9 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
logEntry.GeoLocation.CityName = location.City.Names.En
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
if len(location.Subdivisions) > 0 {
logEntry.SubdivisionCode = location.Subdivisions[0].ISOCode
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"math/rand/v2"
"net/http"
"os"
"slices"
"strconv"
@@ -229,6 +230,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return fmt.Errorf("hash secrets: %w", err)
}
for i, h := range service.Auth.HeaderAuths {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
@@ -488,6 +495,9 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
@@ -544,18 +554,52 @@ func isHTTPFamily(mode string) bool {
return mode == "" || mode == "http"
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) {
if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
svc.Auth.PasswordAuth.Password == "" {
svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
svc.Auth.PinAuth.Pin == "" {
svc.Auth.PinAuth = existingService.Auth.PinAuth
}
preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
}
// preserveHeaderAuthHashes fills in empty header auth values from the existing
// service so that unchanged secrets are not lost on update.
func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) {
if len(headers) == 0 || len(existing) == 0 {
return
}
existingByHeader := make(map[string]string, len(existing))
for _, h := range existing {
if h != nil && h.Value != "" {
existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value
}
}
for _, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok {
h.Value = hash
}
}
}
}
// validateHeaderAuthValues checks that all enabled header auths have a value
// (either freshly provided or preserved from the existing service).
func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error {
for i, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
return nil
}
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
@@ -605,6 +649,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
}
return nil

View File

@@ -7,14 +7,15 @@ import (
"math/big"
"net"
"net/http"
"net/netip"
"net/url"
"regexp"
"slices"
"strconv"
"strings"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
@@ -91,10 +92,37 @@ type BearerAuthConfig struct {
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
// HeaderAuthConfig defines a static header-value auth check.
// The proxy compares the incoming header value against the stored hash.
type HeaderAuthConfig struct {
Enabled bool `json:"enabled"`
Header string `json:"header"`
Value string `json:"value"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"`
}
// AccessRestrictions controls who can connect to the service based on IP or geography.
type AccessRestrictions struct {
AllowedCIDRs []string `json:"allowed_cidrs,omitempty" gorm:"serializer:json"`
BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"`
AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"`
BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"`
}
// Copy returns a deep copy of the AccessRestrictions.
func (r AccessRestrictions) Copy() AccessRestrictions {
return AccessRestrictions{
AllowedCIDRs: slices.Clone(r.AllowedCIDRs),
BlockedCIDRs: slices.Clone(r.BlockedCIDRs),
AllowedCountries: slices.Clone(r.AllowedCountries),
BlockedCountries: slices.Clone(r.BlockedCountries),
}
}
func (a *AuthConfig) HashSecrets() error {
@@ -114,6 +142,16 @@ func (a *AuthConfig) HashSecrets() error {
a.PinAuth.Pin = hashedPin
}
for i, h := range a.HeaderAuths {
if h != nil && h.Enabled && h.Value != "" {
hashedValue, err := argon2id.Hash(h.Value)
if err != nil {
return fmt.Errorf("hash header auth[%d] value: %w", i, err)
}
h.Value = hashedValue
}
}
return nil
}
@@ -124,6 +162,11 @@ func (a *AuthConfig) ClearSecrets() {
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
for _, h := range a.HeaderAuths {
if h != nil {
h.Value = ""
}
}
}
type Meta struct {
@@ -143,12 +186,13 @@ type Service struct {
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
Auth AuthConfig `gorm:"serializer:json"`
Restrictions AccessRestrictions `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
// Mode determines the service type: "http", "tcp", "udp", or "tls".
Mode string `gorm:"default:'http'"`
ListenPort uint16
@@ -188,6 +232,20 @@ func (s *Service) ToAPIResponse() *api.Service {
}
}
if len(s.Auth.HeaderAuths) > 0 {
apiHeaders := make([]api.HeaderAuthConfig, 0, len(s.Auth.HeaderAuths))
for _, h := range s.Auth.HeaderAuths {
if h == nil {
continue
}
apiHeaders = append(apiHeaders, api.HeaderAuthConfig{
Enabled: h.Enabled,
Header: h.Header,
})
}
authConfig.HeaderAuths = &apiHeaders
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
@@ -222,18 +280,19 @@ func (s *Service) ToAPIResponse() *api.Service {
listenPort := int(s.ListenPort)
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
AccessRestrictions: restrictionsToAPI(s.Restrictions),
Meta: meta,
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
}
if s.ProxyCluster != "" {
@@ -263,7 +322,16 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
auth.Oidc = true
}
return &proto.ProxyMapping{
for _, h := range s.Auth.HeaderAuths {
if h != nil && h.Enabled {
auth.HeaderAuths = append(auth.HeaderAuths, &proto.HeaderAuth{
Header: h.Header,
HashedValue: h.Value,
})
}
}
mapping := &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
@@ -276,6 +344,12 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec
}
if r := restrictionsToProto(s.Restrictions); r != nil {
mapping.AccessRestrictions = r
}
return mapping
}
// buildPathMappings constructs PathMapping entries from targets.
@@ -334,8 +408,7 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
panic(fmt.Sprintf("unknown operation type: %v", op))
}
}
@@ -477,6 +550,10 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
s.Auth = authFromAPI(req.Auth)
}
if req.AccessRestrictions != nil {
s.Restrictions = restrictionsFromAPI(req.AccessRestrictions)
}
return nil
}
@@ -538,9 +615,70 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
}
auth.BearerAuth = bearerAuth
}
if reqAuth.HeaderAuths != nil {
for _, h := range *reqAuth.HeaderAuths {
auth.HeaderAuths = append(auth.HeaderAuths, &HeaderAuthConfig{
Enabled: h.Enabled,
Header: h.Header,
Value: h.Value,
})
}
}
return auth
}
func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions {
if r == nil {
return AccessRestrictions{}
}
var res AccessRestrictions
if r.AllowedCidrs != nil {
res.AllowedCIDRs = *r.AllowedCidrs
}
if r.BlockedCidrs != nil {
res.BlockedCIDRs = *r.BlockedCidrs
}
if r.AllowedCountries != nil {
res.AllowedCountries = *r.AllowedCountries
}
if r.BlockedCountries != nil {
res.BlockedCountries = *r.BlockedCountries
}
return res
}
func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions {
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
return nil
}
res := &api.AccessRestrictions{}
if len(r.AllowedCIDRs) > 0 {
res.AllowedCidrs = &r.AllowedCIDRs
}
if len(r.BlockedCIDRs) > 0 {
res.BlockedCidrs = &r.BlockedCIDRs
}
if len(r.AllowedCountries) > 0 {
res.AllowedCountries = &r.AllowedCountries
}
if len(r.BlockedCountries) > 0 {
res.BlockedCountries = &r.BlockedCountries
}
return res
}
func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions {
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
return nil
}
return &proto.AccessRestrictions{
AllowedCidrs: r.AllowedCIDRs,
BlockedCidrs: r.BlockedCIDRs,
AllowedCountries: r.AllowedCountries,
BlockedCountries: r.BlockedCountries,
}
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
@@ -557,6 +695,13 @@ func (s *Service) Validate() error {
s.Mode = ModeHTTP
}
if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil {
return err
}
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err
}
switch s.Mode {
case ModeHTTP:
return s.validateHTTPMode()
@@ -657,6 +802,21 @@ func (s *Service) validateL4Target(target *Target) error {
if target.Path != nil && *target.Path != "" && *target.Path != "/" {
return errors.New("path is not supported for L4 services")
}
if target.Options.SessionIdleTimeout < 0 {
return errors.New("session_idle_timeout must be positive for L4 services")
}
if target.Options.RequestTimeout < 0 {
return errors.New("request_timeout must be positive for L4 services")
}
if target.Options.SkipTLSVerify {
return errors.New("skip_tls_verify is not supported for L4 services")
}
if target.Options.PathRewrite != "" {
return errors.New("path_rewrite is not supported for L4 services")
}
if len(target.Options.CustomHeaders) > 0 {
return errors.New("custom_headers is not supported for L4 services")
}
return nil
}
@@ -688,11 +848,9 @@ func IsPortBasedProtocol(mode string) bool {
}
const (
maxRequestTimeout = 5 * time.Minute
maxSessionIdleTimeout = 10 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
@@ -731,22 +889,12 @@ func validateTargetOptions(idx int, opts *TargetOptions) error {
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
}
if opts.RequestTimeout != 0 {
if opts.RequestTimeout <= 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.RequestTimeout > maxRequestTimeout {
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
}
if opts.RequestTimeout < 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.SessionIdleTimeout != 0 {
if opts.SessionIdleTimeout <= 0 {
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
}
if opts.SessionIdleTimeout > maxSessionIdleTimeout {
return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout)
}
if opts.SessionIdleTimeout < 0 {
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
}
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
@@ -796,6 +944,93 @@ func containsCRLF(s string) bool {
return strings.ContainsAny(s, "\r\n")
}
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
seen := make(map[string]struct{})
for i, h := range headers {
if h == nil || !h.Enabled {
continue
}
if h.Header == "" {
return fmt.Errorf("header_auths[%d]: header name is required", i)
}
if !httpHeaderNameRe.MatchString(h.Header) {
return fmt.Errorf("header_auths[%d]: header name %q is not a valid HTTP header name", i, h.Header)
}
canonical := http.CanonicalHeaderKey(h.Header)
if _, ok := hopByHopHeaders[canonical]; ok {
return fmt.Errorf("header_auths[%d]: header %q is a hop-by-hop header and cannot be used for auth", i, h.Header)
}
if _, ok := reservedHeaders[canonical]; ok {
return fmt.Errorf("header_auths[%d]: header %q is managed by the proxy and cannot be used for auth", i, h.Header)
}
if canonical == "Host" {
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
}
if _, dup := seen[canonical]; dup {
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
}
seen[canonical] = struct{}{}
if len(h.Value) > maxHeaderValueLen {
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
}
}
return nil
}
const (
maxCIDREntries = 200
maxCountryEntries = 50
)
// validateAccessRestrictions validates and normalizes access restriction
// entries. Country codes are uppercased in place.
func validateAccessRestrictions(r *AccessRestrictions) error {
if len(r.AllowedCIDRs) > maxCIDREntries {
return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries)
}
if len(r.BlockedCIDRs) > maxCIDREntries {
return fmt.Errorf("blocked_cidrs: exceeds maximum of %d entries", maxCIDREntries)
}
if len(r.AllowedCountries) > maxCountryEntries {
return fmt.Errorf("allowed_countries: exceeds maximum of %d entries", maxCountryEntries)
}
if len(r.BlockedCountries) > maxCountryEntries {
return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries)
}
for i, raw := range r.AllowedCIDRs {
prefix, err := netip.ParsePrefix(raw)
if err != nil {
return fmt.Errorf("allowed_cidrs[%d]: %w", i, err)
}
if prefix != prefix.Masked() {
return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
}
}
for i, raw := range r.BlockedCIDRs {
prefix, err := netip.ParsePrefix(raw)
if err != nil {
return fmt.Errorf("blocked_cidrs[%d]: %w", i, err)
}
if prefix != prefix.Masked() {
return fmt.Errorf("blocked_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
}
}
for i, code := range r.AllowedCountries {
if len(code) != 2 {
return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
}
r.AllowedCountries[i] = strings.ToUpper(code)
}
for i, code := range r.BlockedCountries {
if len(code) != 2 {
return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
}
r.BlockedCountries[i] = strings.ToUpper(code)
}
return nil
}
func (s *Service) EventMeta() map[string]any {
meta := map[string]any{
"name": s.Name,
@@ -827,9 +1062,17 @@ func (s *Service) EventMeta() map[string]any {
}
func (s *Service) isAuthEnabled() bool {
return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled)
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) {
return true
}
for _, h := range s.Auth.HeaderAuths {
if h != nil && h.Enabled {
return true
}
}
return false
}
func (s *Service) Copy() *Service {
@@ -866,6 +1109,16 @@ func (s *Service) Copy() *Service {
}
authCopy.BearerAuth = &ba
}
if len(s.Auth.HeaderAuths) > 0 {
authCopy.HeaderAuths = make([]*HeaderAuthConfig, len(s.Auth.HeaderAuths))
for i, h := range s.Auth.HeaderAuths {
if h == nil {
continue
}
hCopy := *h
authCopy.HeaderAuths[i] = &hCopy
}
}
return &Service{
ID: s.ID,
@@ -878,6 +1131,7 @@ func (s *Service) Copy() *Service {
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: authCopy,
Restrictions: s.Restrictions.Copy(),
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,

View File

@@ -120,9 +120,9 @@ func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
}{
{"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""},
{"valid 10m", 10 * time.Minute, ""},
{"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
@@ -493,16 +494,17 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
// should be set on the copy.
func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
return &proto.ProxyMapping{
Type: m.Type,
Id: m.Id,
AccountId: m.AccountId,
Domain: m.Domain,
Path: m.Path,
Auth: m.Auth,
PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects,
Mode: m.Mode,
ListenPort: m.ListenPort,
Type: m.Type,
Id: m.Id,
AccountId: m.AccountId,
Domain: m.Domain,
Path: m.Path,
Auth: m.Auth,
PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects,
Mode: m.Mode,
ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
}
}
@@ -561,6 +563,8 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
case *proto.AuthenticateRequest_Password:
return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth)
case *proto.AuthenticateRequest_HeaderAuth:
return s.authenticateHeader(ctx, req.GetId(), v, service.Auth.HeaderAuths)
default:
return false, "", ""
}
@@ -594,6 +598,35 @@ func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID
return true, "password-user", proxyauth.MethodPassword
}
func (s *ProxyServiceServer) authenticateHeader(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_HeaderAuth, auths []*rpservice.HeaderAuthConfig) (bool, string, proxyauth.Method) {
if len(auths) == 0 {
log.WithContext(ctx).Debugf("header authentication attempted but no header auths configured for service %s", serviceID)
return false, "", ""
}
headerName := http.CanonicalHeaderKey(req.HeaderAuth.GetHeaderName())
var lastErr error
for _, auth := range auths {
if auth == nil || !auth.Enabled {
continue
}
if headerName != "" && http.CanonicalHeaderKey(auth.Header) != headerName {
continue
}
if err := argon2id.Verify(req.HeaderAuth.GetHeaderValue(), auth.Value); err != nil {
lastErr = err
continue
}
return true, "header-user", proxyauth.MethodHeader
}
if lastErr != nil {
s.logAuthenticationError(ctx, lastErr, "Header")
}
return false, "", ""
}
func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) {
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType)
@@ -752,6 +785,9 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
}
if redirectURL.Scheme != "https" && redirectURL.Scheme != "http" {
return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme")
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
@@ -836,12 +872,9 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid.
// The HMAC is verified before consuming the PKCE verifier to prevent
// an attacker from invalidating a legitimate user's auth flow.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|")
if len(parts) != 3 {
@@ -865,6 +898,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return "", "", errors.New("invalid state signature")
}
// Consume the PKCE verifier only after HMAC validation passes.
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
return verifier, redirectURL, nil
}