[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -20,22 +20,23 @@ const (
) )
type AccessLogEntry struct { type AccessLogEntry struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
ServiceID string `gorm:"index"` ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"` Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"` GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string `gorm:"index"` SubdivisionCode string
Host string `gorm:"index"` Method string `gorm:"index"`
Path string `gorm:"index"` Host string `gorm:"index"`
Duration time.Duration `gorm:"index"` Path string `gorm:"index"`
StatusCode int `gorm:"index"` Duration time.Duration `gorm:"index"`
Reason string StatusCode int `gorm:"index"`
UserId string `gorm:"index"` Reason string
AuthMethodUsed string `gorm:"index"` UserId string `gorm:"index"`
BytesUpload int64 `gorm:"index"` AuthMethodUsed string `gorm:"index"`
BytesDownload int64 `gorm:"index"` BytesUpload int64 `gorm:"index"`
Protocol AccessLogProtocol `gorm:"index"` BytesDownload int64 `gorm:"index"`
Protocol AccessLogProtocol `gorm:"index"`
} }
// FromProto creates an AccessLogEntry from a proto.AccessLog // FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -105,6 +106,11 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
cityName = &a.GeoLocation.CityName cityName = &a.GeoLocation.CityName
} }
var subdivisionCode *string
if a.SubdivisionCode != "" {
subdivisionCode = &a.SubdivisionCode
}
var protocol *string var protocol *string
if a.Protocol != "" { if a.Protocol != "" {
p := string(a.Protocol) p := string(a.Protocol)
@@ -112,22 +118,23 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
} }
return &api.ProxyAccessLog{ return &api.ProxyAccessLog{
Id: a.ID, Id: a.ID,
ServiceId: a.ServiceID, ServiceId: a.ServiceID,
Timestamp: a.Timestamp, Timestamp: a.Timestamp,
Method: a.Method, Method: a.Method,
Host: a.Host, Host: a.Host,
Path: a.Path, Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()), DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode, StatusCode: a.StatusCode,
SourceIp: sourceIP, SourceIp: sourceIP,
Reason: reason, Reason: reason,
UserId: userID, UserId: userID,
AuthMethodUsed: authMethod, AuthMethodUsed: authMethod,
CountryCode: countryCode, CountryCode: countryCode,
CityName: cityName, CityName: cityName,
BytesUpload: a.BytesUpload, SubdivisionCode: subdivisionCode,
BytesDownload: a.BytesDownload, BytesUpload: a.BytesUpload,
Protocol: protocol, BytesDownload: a.BytesDownload,
Protocol: protocol,
} }
} }

View File

@@ -41,6 +41,9 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
logEntry.GeoLocation.CountryCode = location.Country.ISOCode logEntry.GeoLocation.CountryCode = location.Country.ISOCode
logEntry.GeoLocation.CityName = location.City.Names.En logEntry.GeoLocation.CityName = location.City.Names.En
logEntry.GeoLocation.GeoNameID = location.City.GeonameID logEntry.GeoLocation.GeoNameID = location.City.GeonameID
if len(location.Subdivisions) > 0 {
logEntry.SubdivisionCode = location.Subdivisions[0].ISOCode
}
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"math/rand/v2" "math/rand/v2"
"net/http"
"os" "os"
"slices" "slices"
"strconv" "strconv"
@@ -229,6 +230,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return fmt.Errorf("hash secrets: %w", err) return fmt.Errorf("hash secrets: %w", err)
} }
for i, h := range service.Auth.HeaderAuths {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
keyPair, err := sessionkey.GenerateKeyPair() keyPair, err := sessionkey.GenerateKeyPair()
if err != nil { if err != nil {
return fmt.Errorf("generate session keys: %w", err) return fmt.Errorf("generate session keys: %w", err)
@@ -488,6 +495,9 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
} }
m.preserveExistingAuthSecrets(service, existingService) m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService) m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService) m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
@@ -544,18 +554,52 @@ func isHTTPFamily(mode string) bool {
return mode == "" || mode == "http" return mode == "" || mode == "http"
} }
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) { func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" { svc.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth
} }
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled && if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled && existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" { svc.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth svc.Auth.PinAuth = existingService.Auth.PinAuth
} }
preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
}
// preserveHeaderAuthHashes fills in empty header auth values from the existing
// service so that unchanged secrets are not lost on update.
func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) {
if len(headers) == 0 || len(existing) == 0 {
return
}
existingByHeader := make(map[string]string, len(existing))
for _, h := range existing {
if h != nil && h.Value != "" {
existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value
}
}
for _, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok {
h.Value = hash
}
}
}
}
// validateHeaderAuthValues checks that all enabled header auths have a value
// (either freshly provided or preserved from the existing service).
func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error {
for i, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
return nil
} }
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) { func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
@@ -605,6 +649,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
} }
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
} }
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
} }
} }
return nil return nil

View File

@@ -7,14 +7,15 @@ import (
"math/big" "math/big"
"net" "net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"regexp" "regexp"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
@@ -91,10 +92,37 @@ type BearerAuthConfig struct {
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"` DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
} }
// HeaderAuthConfig defines a static header-value auth check.
// The proxy compares the incoming header value against the stored hash.
type HeaderAuthConfig struct {
Enabled bool `json:"enabled"`
Header string `json:"header"`
Value string `json:"value"`
}
type AuthConfig struct { type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"` PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"` PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"` BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"`
}
// AccessRestrictions controls who can connect to the service based on IP or geography.
type AccessRestrictions struct {
AllowedCIDRs []string `json:"allowed_cidrs,omitempty" gorm:"serializer:json"`
BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"`
AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"`
BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"`
}
// Copy returns a deep copy of the AccessRestrictions.
func (r AccessRestrictions) Copy() AccessRestrictions {
return AccessRestrictions{
AllowedCIDRs: slices.Clone(r.AllowedCIDRs),
BlockedCIDRs: slices.Clone(r.BlockedCIDRs),
AllowedCountries: slices.Clone(r.AllowedCountries),
BlockedCountries: slices.Clone(r.BlockedCountries),
}
} }
func (a *AuthConfig) HashSecrets() error { func (a *AuthConfig) HashSecrets() error {
@@ -114,6 +142,16 @@ func (a *AuthConfig) HashSecrets() error {
a.PinAuth.Pin = hashedPin a.PinAuth.Pin = hashedPin
} }
for i, h := range a.HeaderAuths {
if h != nil && h.Enabled && h.Value != "" {
hashedValue, err := argon2id.Hash(h.Value)
if err != nil {
return fmt.Errorf("hash header auth[%d] value: %w", i, err)
}
h.Value = hashedValue
}
}
return nil return nil
} }
@@ -124,6 +162,11 @@ func (a *AuthConfig) ClearSecrets() {
if a.PinAuth != nil { if a.PinAuth != nil {
a.PinAuth.Pin = "" a.PinAuth.Pin = ""
} }
for _, h := range a.HeaderAuths {
if h != nil {
h.Value = ""
}
}
} }
type Meta struct { type Meta struct {
@@ -143,12 +186,13 @@ type Service struct {
Enabled bool Enabled bool
PassHostHeader bool PassHostHeader bool
RewriteRedirects bool RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"` Auth AuthConfig `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"` Restrictions AccessRestrictions `gorm:"serializer:json"`
SessionPrivateKey string `gorm:"column:session_private_key"` Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPublicKey string `gorm:"column:session_public_key"` SessionPrivateKey string `gorm:"column:session_private_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"` SessionPublicKey string `gorm:"column:session_public_key"`
SourcePeer string `gorm:"index:idx_service_source_peer"` Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
// Mode determines the service type: "http", "tcp", "udp", or "tls". // Mode determines the service type: "http", "tcp", "udp", or "tls".
Mode string `gorm:"default:'http'"` Mode string `gorm:"default:'http'"`
ListenPort uint16 ListenPort uint16
@@ -188,6 +232,20 @@ func (s *Service) ToAPIResponse() *api.Service {
} }
} }
if len(s.Auth.HeaderAuths) > 0 {
apiHeaders := make([]api.HeaderAuthConfig, 0, len(s.Auth.HeaderAuths))
for _, h := range s.Auth.HeaderAuths {
if h == nil {
continue
}
apiHeaders = append(apiHeaders, api.HeaderAuthConfig{
Enabled: h.Enabled,
Header: h.Header,
})
}
authConfig.HeaderAuths = &apiHeaders
}
// Convert internal targets to API targets // Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets)) apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets { for _, target := range s.Targets {
@@ -222,18 +280,19 @@ func (s *Service) ToAPIResponse() *api.Service {
listenPort := int(s.ListenPort) listenPort := int(s.ListenPort)
resp := &api.Service{ resp := &api.Service{
Id: s.ID, Id: s.ID,
Name: s.Name, Name: s.Name,
Domain: s.Domain, Domain: s.Domain,
Targets: apiTargets, Targets: apiTargets,
Enabled: s.Enabled, Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader, PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects, RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig, Auth: authConfig,
Meta: meta, AccessRestrictions: restrictionsToAPI(s.Restrictions),
Mode: &mode, Meta: meta,
ListenPort: &listenPort, Mode: &mode,
PortAutoAssigned: &s.PortAutoAssigned, ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
} }
if s.ProxyCluster != "" { if s.ProxyCluster != "" {
@@ -263,7 +322,16 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
auth.Oidc = true auth.Oidc = true
} }
return &proto.ProxyMapping{ for _, h := range s.Auth.HeaderAuths {
if h != nil && h.Enabled {
auth.HeaderAuths = append(auth.HeaderAuths, &proto.HeaderAuth{
Header: h.Header,
HashedValue: h.Value,
})
}
}
mapping := &proto.ProxyMapping{
Type: operationToProtoType(operation), Type: operationToProtoType(operation),
Id: s.ID, Id: s.ID,
Domain: s.Domain, Domain: s.Domain,
@@ -276,6 +344,12 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
Mode: s.Mode, Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec ListenPort: int32(s.ListenPort), //nolint:gosec
} }
if r := restrictionsToProto(s.Restrictions); r != nil {
mapping.AccessRestrictions = r
}
return mapping
} }
// buildPathMappings constructs PathMapping entries from targets. // buildPathMappings constructs PathMapping entries from targets.
@@ -334,8 +408,7 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
case Delete: case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default: default:
log.Fatalf("unknown operation type: %v", op) panic(fmt.Sprintf("unknown operation type: %v", op))
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
} }
} }
@@ -477,6 +550,10 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
s.Auth = authFromAPI(req.Auth) s.Auth = authFromAPI(req.Auth)
} }
if req.AccessRestrictions != nil {
s.Restrictions = restrictionsFromAPI(req.AccessRestrictions)
}
return nil return nil
} }
@@ -538,9 +615,70 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
} }
auth.BearerAuth = bearerAuth auth.BearerAuth = bearerAuth
} }
if reqAuth.HeaderAuths != nil {
for _, h := range *reqAuth.HeaderAuths {
auth.HeaderAuths = append(auth.HeaderAuths, &HeaderAuthConfig{
Enabled: h.Enabled,
Header: h.Header,
Value: h.Value,
})
}
}
return auth return auth
} }
func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions {
if r == nil {
return AccessRestrictions{}
}
var res AccessRestrictions
if r.AllowedCidrs != nil {
res.AllowedCIDRs = *r.AllowedCidrs
}
if r.BlockedCidrs != nil {
res.BlockedCIDRs = *r.BlockedCidrs
}
if r.AllowedCountries != nil {
res.AllowedCountries = *r.AllowedCountries
}
if r.BlockedCountries != nil {
res.BlockedCountries = *r.BlockedCountries
}
return res
}
func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions {
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
return nil
}
res := &api.AccessRestrictions{}
if len(r.AllowedCIDRs) > 0 {
res.AllowedCidrs = &r.AllowedCIDRs
}
if len(r.BlockedCIDRs) > 0 {
res.BlockedCidrs = &r.BlockedCIDRs
}
if len(r.AllowedCountries) > 0 {
res.AllowedCountries = &r.AllowedCountries
}
if len(r.BlockedCountries) > 0 {
res.BlockedCountries = &r.BlockedCountries
}
return res
}
func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions {
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
return nil
}
return &proto.AccessRestrictions{
AllowedCidrs: r.AllowedCIDRs,
BlockedCidrs: r.BlockedCIDRs,
AllowedCountries: r.AllowedCountries,
BlockedCountries: r.BlockedCountries,
}
}
func (s *Service) Validate() error { func (s *Service) Validate() error {
if s.Name == "" { if s.Name == "" {
return errors.New("service name is required") return errors.New("service name is required")
@@ -557,6 +695,13 @@ func (s *Service) Validate() error {
s.Mode = ModeHTTP s.Mode = ModeHTTP
} }
if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil {
return err
}
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err
}
switch s.Mode { switch s.Mode {
case ModeHTTP: case ModeHTTP:
return s.validateHTTPMode() return s.validateHTTPMode()
@@ -657,6 +802,21 @@ func (s *Service) validateL4Target(target *Target) error {
if target.Path != nil && *target.Path != "" && *target.Path != "/" { if target.Path != nil && *target.Path != "" && *target.Path != "/" {
return errors.New("path is not supported for L4 services") return errors.New("path is not supported for L4 services")
} }
if target.Options.SessionIdleTimeout < 0 {
return errors.New("session_idle_timeout must be positive for L4 services")
}
if target.Options.RequestTimeout < 0 {
return errors.New("request_timeout must be positive for L4 services")
}
if target.Options.SkipTLSVerify {
return errors.New("skip_tls_verify is not supported for L4 services")
}
if target.Options.PathRewrite != "" {
return errors.New("path_rewrite is not supported for L4 services")
}
if len(target.Options.CustomHeaders) > 0 {
return errors.New("custom_headers is not supported for L4 services")
}
return nil return nil
} }
@@ -688,11 +848,9 @@ func IsPortBasedProtocol(mode string) bool {
} }
const ( const (
maxRequestTimeout = 5 * time.Minute maxCustomHeaders = 16
maxSessionIdleTimeout = 10 * time.Minute maxHeaderKeyLen = 128
maxCustomHeaders = 16 maxHeaderValueLen = 4096
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
) )
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition. // httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
@@ -731,22 +889,12 @@ func validateTargetOptions(idx int, opts *TargetOptions) error {
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite) return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
} }
if opts.RequestTimeout != 0 { if opts.RequestTimeout < 0 {
if opts.RequestTimeout <= 0 { return fmt.Errorf("target %d: request_timeout must be positive", idx)
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.RequestTimeout > maxRequestTimeout {
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
}
} }
if opts.SessionIdleTimeout != 0 { if opts.SessionIdleTimeout < 0 {
if opts.SessionIdleTimeout <= 0 { return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
}
if opts.SessionIdleTimeout > maxSessionIdleTimeout {
return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout)
}
} }
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil { if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
@@ -796,6 +944,93 @@ func containsCRLF(s string) bool {
return strings.ContainsAny(s, "\r\n") return strings.ContainsAny(s, "\r\n")
} }
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
seen := make(map[string]struct{})
for i, h := range headers {
if h == nil || !h.Enabled {
continue
}
if h.Header == "" {
return fmt.Errorf("header_auths[%d]: header name is required", i)
}
if !httpHeaderNameRe.MatchString(h.Header) {
return fmt.Errorf("header_auths[%d]: header name %q is not a valid HTTP header name", i, h.Header)
}
canonical := http.CanonicalHeaderKey(h.Header)
if _, ok := hopByHopHeaders[canonical]; ok {
return fmt.Errorf("header_auths[%d]: header %q is a hop-by-hop header and cannot be used for auth", i, h.Header)
}
if _, ok := reservedHeaders[canonical]; ok {
return fmt.Errorf("header_auths[%d]: header %q is managed by the proxy and cannot be used for auth", i, h.Header)
}
if canonical == "Host" {
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
}
if _, dup := seen[canonical]; dup {
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
}
seen[canonical] = struct{}{}
if len(h.Value) > maxHeaderValueLen {
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
}
}
return nil
}
const (
maxCIDREntries = 200
maxCountryEntries = 50
)
// validateAccessRestrictions validates and normalizes access restriction
// entries. Country codes are uppercased in place.
func validateAccessRestrictions(r *AccessRestrictions) error {
if len(r.AllowedCIDRs) > maxCIDREntries {
return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries)
}
if len(r.BlockedCIDRs) > maxCIDREntries {
return fmt.Errorf("blocked_cidrs: exceeds maximum of %d entries", maxCIDREntries)
}
if len(r.AllowedCountries) > maxCountryEntries {
return fmt.Errorf("allowed_countries: exceeds maximum of %d entries", maxCountryEntries)
}
if len(r.BlockedCountries) > maxCountryEntries {
return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries)
}
for i, raw := range r.AllowedCIDRs {
prefix, err := netip.ParsePrefix(raw)
if err != nil {
return fmt.Errorf("allowed_cidrs[%d]: %w", i, err)
}
if prefix != prefix.Masked() {
return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
}
}
for i, raw := range r.BlockedCIDRs {
prefix, err := netip.ParsePrefix(raw)
if err != nil {
return fmt.Errorf("blocked_cidrs[%d]: %w", i, err)
}
if prefix != prefix.Masked() {
return fmt.Errorf("blocked_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
}
}
for i, code := range r.AllowedCountries {
if len(code) != 2 {
return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
}
r.AllowedCountries[i] = strings.ToUpper(code)
}
for i, code := range r.BlockedCountries {
if len(code) != 2 {
return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
}
r.BlockedCountries[i] = strings.ToUpper(code)
}
return nil
}
func (s *Service) EventMeta() map[string]any { func (s *Service) EventMeta() map[string]any {
meta := map[string]any{ meta := map[string]any{
"name": s.Name, "name": s.Name,
@@ -827,9 +1062,17 @@ func (s *Service) EventMeta() map[string]any {
} }
func (s *Service) isAuthEnabled() bool { func (s *Service) isAuthEnabled() bool {
return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) || if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) || (s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) (s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) {
return true
}
for _, h := range s.Auth.HeaderAuths {
if h != nil && h.Enabled {
return true
}
}
return false
} }
func (s *Service) Copy() *Service { func (s *Service) Copy() *Service {
@@ -866,6 +1109,16 @@ func (s *Service) Copy() *Service {
} }
authCopy.BearerAuth = &ba authCopy.BearerAuth = &ba
} }
if len(s.Auth.HeaderAuths) > 0 {
authCopy.HeaderAuths = make([]*HeaderAuthConfig, len(s.Auth.HeaderAuths))
for i, h := range s.Auth.HeaderAuths {
if h == nil {
continue
}
hCopy := *h
authCopy.HeaderAuths[i] = &hCopy
}
}
return &Service{ return &Service{
ID: s.ID, ID: s.ID,
@@ -878,6 +1131,7 @@ func (s *Service) Copy() *Service {
PassHostHeader: s.PassHostHeader, PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects, RewriteRedirects: s.RewriteRedirects,
Auth: authCopy, Auth: authCopy,
Restrictions: s.Restrictions.Copy(),
Meta: s.Meta, Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey, SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey, SessionPublicKey: s.SessionPublicKey,

View File

@@ -120,9 +120,9 @@ func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
}{ }{
{"valid 30s", 30 * time.Second, ""}, {"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""}, {"valid 2m", 2 * time.Minute, ""},
{"valid 10m", 10 * time.Minute, ""},
{"zero is fine", 0, ""}, {"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"}, {"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
@@ -493,16 +494,17 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
// should be set on the copy. // should be set on the copy.
func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
return &proto.ProxyMapping{ return &proto.ProxyMapping{
Type: m.Type, Type: m.Type,
Id: m.Id, Id: m.Id,
AccountId: m.AccountId, AccountId: m.AccountId,
Domain: m.Domain, Domain: m.Domain,
Path: m.Path, Path: m.Path,
Auth: m.Auth, Auth: m.Auth,
PassHostHeader: m.PassHostHeader, PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects, RewriteRedirects: m.RewriteRedirects,
Mode: m.Mode, Mode: m.Mode,
ListenPort: m.ListenPort, ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
} }
} }
@@ -561,6 +563,8 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth) return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
case *proto.AuthenticateRequest_Password: case *proto.AuthenticateRequest_Password:
return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth) return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth)
case *proto.AuthenticateRequest_HeaderAuth:
return s.authenticateHeader(ctx, req.GetId(), v, service.Auth.HeaderAuths)
default: default:
return false, "", "" return false, "", ""
} }
@@ -594,6 +598,35 @@ func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID
return true, "password-user", proxyauth.MethodPassword return true, "password-user", proxyauth.MethodPassword
} }
func (s *ProxyServiceServer) authenticateHeader(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_HeaderAuth, auths []*rpservice.HeaderAuthConfig) (bool, string, proxyauth.Method) {
if len(auths) == 0 {
log.WithContext(ctx).Debugf("header authentication attempted but no header auths configured for service %s", serviceID)
return false, "", ""
}
headerName := http.CanonicalHeaderKey(req.HeaderAuth.GetHeaderName())
var lastErr error
for _, auth := range auths {
if auth == nil || !auth.Enabled {
continue
}
if headerName != "" && http.CanonicalHeaderKey(auth.Header) != headerName {
continue
}
if err := argon2id.Verify(req.HeaderAuth.GetHeaderValue(), auth.Value); err != nil {
lastErr = err
continue
}
return true, "header-user", proxyauth.MethodHeader
}
if lastErr != nil {
s.logAuthenticationError(ctx, lastErr, "Header")
}
return false, "", ""
}
func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) { func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) {
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) { if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType) log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType)
@@ -752,6 +785,9 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
} }
if redirectURL.Scheme != "https" && redirectURL.Scheme != "http" {
return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme")
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId()) services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil { if err != nil {
@@ -836,12 +872,9 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback. // ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid. // Returns the original redirect URL if valid, or an error if invalid.
// The HMAC is verified before consuming the PKCE verifier to prevent
// an attacker from invalidating a legitimate user's auth flow.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) { func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|") parts := strings.Split(state, "|")
if len(parts) != 3 { if len(parts) != 3 {
@@ -865,6 +898,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return "", "", errors.New("invalid state signature") return "", "", errors.New("invalid state signature")
} }
// Consume the PKCE verifier only after HMAC validation passes.
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
return verifier, redirectURL, nil return verifier, redirectURL, nil
} }

View File

@@ -44,6 +44,12 @@ type Record struct {
GeonameID uint `maxminddb:"geoname_id"` GeonameID uint `maxminddb:"geoname_id"`
ISOCode string `maxminddb:"iso_code"` ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"` } `maxminddb:"country"`
Subdivisions []struct {
ISOCode string `maxminddb:"iso_code"`
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"subdivisions"`
} }
type City struct { type City struct {

View File

@@ -10,7 +10,7 @@ FROM gcr.io/distroless/base:debug
COPY netbird-proxy /go/bin/netbird-proxy COPY netbird-proxy /go/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group COPY --from=builder /tmp/group /etc/group
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
USER netbird:netbird USER netbird:netbird
ENV HOME=/var/lib/netbird ENV HOME=/var/lib/netbird

View File

@@ -28,7 +28,7 @@ FROM gcr.io/distroless/base:debug
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group COPY --from=builder /tmp/group /etc/group
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
USER netbird:netbird USER netbird:netbird
ENV HOME=/var/lib/netbird ENV HOME=/var/lib/netbird

View File

@@ -13,10 +13,11 @@ import (
type Method string type Method string
var ( const (
MethodPassword Method = "password" MethodPassword Method = "password"
MethodPIN Method = "pin" MethodPIN Method = "pin"
MethodOIDC Method = "oidc" MethodOIDC Method = "oidc"
MethodHeader Method = "header"
) )
func (m Method) String() string { func (m Method) String() string {

View File

@@ -36,31 +36,33 @@ var (
var ( var (
logLevel string logLevel string
debugLogs bool debugLogs bool
mgmtAddr string mgmtAddr string
addr string addr string
proxyDomain string proxyDomain string
defaultDialTimeout time.Duration maxDialTimeout time.Duration
certDir string maxSessionIdleTimeout time.Duration
acmeCerts bool certDir string
acmeAddr string acmeCerts bool
acmeDir string acmeAddr string
acmeEABKID string acmeDir string
acmeEABHMACKey string acmeEABKID string
acmeChallengeType string acmeEABHMACKey string
debugEndpoint bool acmeChallengeType string
debugEndpointAddr string debugEndpoint bool
healthAddr string debugEndpointAddr string
forwardedProto string healthAddr string
trustedProxies string forwardedProto string
certFile string trustedProxies string
certKeyFile string certFile string
certLockMethod string certKeyFile string
wildcardCertDir string certLockMethod string
wgPort uint16 wildcardCertDir string
proxyProtocol bool wgPort uint16
preSharedKey string proxyProtocol bool
supportsCustomPorts bool preSharedKey string
supportsCustomPorts bool
geoDataDir string
) )
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
@@ -99,7 +101,9 @@ func init() {
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)") rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
} }
// Execute runs the root command. // Execute runs the root command.
@@ -177,17 +181,15 @@ func runServer(cmd *cobra.Command, args []string) error {
ProxyProtocol: proxyProtocol, ProxyProtocol: proxyProtocol,
PreSharedKey: preSharedKey, PreSharedKey: preSharedKey,
SupportsCustomPorts: supportsCustomPorts, SupportsCustomPorts: supportsCustomPorts,
DefaultDialTimeout: defaultDialTimeout, MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
GeoDataDir: geoDataDir,
} }
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop() defer stop()
if err := srv.ListenAndServe(ctx, addr); err != nil { return srv.ListenAndServe(ctx, addr)
logger.Error(err)
return err
}
return nil
} }
func envBoolOrDefault(key string, def bool) bool { func envBoolOrDefault(key string, def bool) bool {
@@ -197,6 +199,7 @@ func envBoolOrDefault(key string, def bool) bool {
} }
parsed, err := strconv.ParseBool(v) parsed, err := strconv.ParseBool(v)
if err != nil { if err != nil {
log.Warnf("parse %s=%q: %v, using default %v", key, v, err, def)
return def return def
} }
return parsed return parsed
@@ -217,6 +220,7 @@ func envUint16OrDefault(key string, def uint16) uint16 {
} }
parsed, err := strconv.ParseUint(v, 10, 16) parsed, err := strconv.ParseUint(v, 10, 16)
if err != nil { if err != nil {
log.Warnf("parse %s=%q: %v, using default %d", key, v, err, def)
return def return def
} }
return uint16(parsed) return uint16(parsed)
@@ -229,6 +233,7 @@ func envDurationOrDefault(key string, def time.Duration) time.Duration {
} }
parsed, err := time.ParseDuration(v) parsed, err := time.ParseDuration(v)
if err != nil { if err != nil {
log.Warnf("parse %s=%q: %v, using default %s", key, v, err, def)
return def return def
} }
return parsed return parsed

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
@@ -22,6 +23,16 @@ const (
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
logSendTimeout = 10 * time.Second logSendTimeout = 10 * time.Second
// denyCooldown is the min interval between deny log entries per service+reason
// to prevent flooding from denied connections (e.g. UDP packets from blocked IPs).
denyCooldown = 10 * time.Second
// maxDenyBuckets caps tracked deny rate-limit entries to bound memory under DDoS.
maxDenyBuckets = 10000
// maxLogWorkers caps concurrent gRPC send goroutines.
maxLogWorkers = 4096
) )
type domainUsage struct { type domainUsage struct {
@@ -38,6 +49,18 @@ type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error) SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
} }
// denyBucketKey identifies a rate-limited deny log stream.
type denyBucketKey struct {
ServiceID types.ServiceID
Reason string
}
// denyBucket tracks rate-limited deny log entries.
type denyBucket struct {
lastLogged time.Time
suppressed int64
}
// Logger sends access log entries to the management server via gRPC. // Logger sends access log entries to the management server via gRPC.
type Logger struct { type Logger struct {
client gRPCClient client gRPCClient
@@ -47,7 +70,12 @@ type Logger struct {
usageMux sync.Mutex usageMux sync.Mutex
domainUsage map[string]*domainUsage domainUsage map[string]*domainUsage
denyMu sync.Mutex
denyBuckets map[denyBucketKey]*denyBucket
logSem chan struct{}
cleanupCancel context.CancelFunc cleanupCancel context.CancelFunc
dropped atomic.Int64
} }
// NewLogger creates a new access log Logger. The trustedProxies parameter // NewLogger creates a new access log Logger. The trustedProxies parameter
@@ -64,6 +92,8 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre
logger: logger, logger: logger,
trustedProxies: trustedProxies, trustedProxies: trustedProxies,
domainUsage: make(map[string]*domainUsage), domainUsage: make(map[string]*domainUsage),
denyBuckets: make(map[denyBucketKey]*denyBucket),
logSem: make(chan struct{}, maxLogWorkers),
cleanupCancel: cancel, cleanupCancel: cancel,
} }
@@ -83,7 +113,7 @@ func (l *Logger) Close() {
type logEntry struct { type logEntry struct {
ID string ID string
AccountID types.AccountID AccountID types.AccountID
ServiceId types.ServiceID ServiceID types.ServiceID
Host string Host string
Path string Path string
DurationMs int64 DurationMs int64
@@ -91,7 +121,7 @@ type logEntry struct {
ResponseCode int32 ResponseCode int32
SourceIP netip.Addr SourceIP netip.Addr
AuthMechanism string AuthMechanism string
UserId string UserID string
AuthSuccess bool AuthSuccess bool
BytesUpload int64 BytesUpload int64
BytesDownload int64 BytesDownload int64
@@ -118,6 +148,10 @@ type L4Entry struct {
DurationMs int64 DurationMs int64
BytesUpload int64 BytesUpload int64
BytesDownload int64 BytesDownload int64
// DenyReason, when non-empty, indicates the connection was denied.
// Values match the HTTP auth mechanism strings: "ip_restricted",
// "country_restricted", "geo_unavailable".
DenyReason string
} }
// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP). // LogL4 sends an access log entry for a layer-4 connection (TCP or UDP).
@@ -126,7 +160,7 @@ func (l *Logger) LogL4(entry L4Entry) {
le := logEntry{ le := logEntry{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: entry.AccountID, AccountID: entry.AccountID,
ServiceId: entry.ServiceID, ServiceID: entry.ServiceID,
Protocol: entry.Protocol, Protocol: entry.Protocol,
Host: entry.Host, Host: entry.Host,
SourceIP: entry.SourceIP, SourceIP: entry.SourceIP,
@@ -134,10 +168,47 @@ func (l *Logger) LogL4(entry L4Entry) {
BytesUpload: entry.BytesUpload, BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload, BytesDownload: entry.BytesDownload,
} }
if entry.DenyReason != "" {
if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) {
return
}
le.AuthMechanism = entry.DenyReason
le.AuthSuccess = false
}
l.log(le) l.log(le)
l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload) l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload)
} }
// allowDenyLog rate-limits deny log entries per service+reason combination.
func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool {
key := denyBucketKey{ServiceID: serviceID, Reason: reason}
now := time.Now()
l.denyMu.Lock()
defer l.denyMu.Unlock()
b, ok := l.denyBuckets[key]
if !ok {
if len(l.denyBuckets) >= maxDenyBuckets {
return false
}
l.denyBuckets[key] = &denyBucket{lastLogged: now}
return true
}
if now.Sub(b.lastLogged) >= denyCooldown {
if b.suppressed > 0 {
l.logger.Debugf("access restriction: suppressed %d deny log entries for %s (%s)", b.suppressed, serviceID, reason)
}
b.lastLogged = now
b.suppressed = 0
return true
}
b.suppressed++
return false
}
func (l *Logger) log(entry logEntry) { func (l *Logger) log(entry logEntry) {
// Fire off the log request in a separate routine. // Fire off the log request in a separate routine.
// This increases the possibility of losing a log message // This increases the possibility of losing a log message
@@ -147,12 +218,21 @@ func (l *Logger) log(entry logEntry) {
// There is also a chance that log messages will arrive at // There is also a chance that log messages will arrive at
// the server out of order; however, the timestamp should // the server out of order; however, the timestamp should
// allow for resolving that on the server. // allow for resolving that on the server.
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary. now := timestamppb.Now()
select {
case l.logSem <- struct{}{}:
default:
total := l.dropped.Add(1)
l.logger.Debugf("access log send dropped: worker limit reached (total dropped: %d)", total)
return
}
go func() { go func() {
defer func() { <-l.logSem }()
logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout) logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout)
defer cancel() defer cancel()
// Only OIDC sessions have a meaningful user identity.
if entry.AuthMechanism != auth.MethodOIDC.String() { if entry.AuthMechanism != auth.MethodOIDC.String() {
entry.UserId = "" entry.UserID = ""
} }
var sourceIP string var sourceIP string
@@ -165,7 +245,7 @@ func (l *Logger) log(entry logEntry) {
LogId: entry.ID, LogId: entry.ID,
AccountId: string(entry.AccountID), AccountId: string(entry.AccountID),
Timestamp: now, Timestamp: now,
ServiceId: string(entry.ServiceId), ServiceId: string(entry.ServiceID),
Host: entry.Host, Host: entry.Host,
Path: entry.Path, Path: entry.Path,
DurationMs: entry.DurationMs, DurationMs: entry.DurationMs,
@@ -173,7 +253,7 @@ func (l *Logger) log(entry logEntry) {
ResponseCode: entry.ResponseCode, ResponseCode: entry.ResponseCode,
SourceIp: sourceIP, SourceIp: sourceIP,
AuthMechanism: entry.AuthMechanism, AuthMechanism: entry.AuthMechanism,
UserId: entry.UserId, UserId: entry.UserID,
AuthSuccess: entry.AuthSuccess, AuthSuccess: entry.AuthSuccess,
BytesUpload: entry.BytesUpload, BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload, BytesDownload: entry.BytesDownload,
@@ -181,7 +261,7 @@ func (l *Logger) log(entry logEntry) {
}, },
}); err != nil { }); err != nil {
l.logger.WithFields(log.Fields{ l.logger.WithFields(log.Fields{
"service_id": entry.ServiceId, "service_id": entry.ServiceID,
"host": entry.Host, "host": entry.Host,
"path": entry.Path, "path": entry.Path,
"duration": entry.DurationMs, "duration": entry.DurationMs,
@@ -189,7 +269,7 @@ func (l *Logger) log(entry logEntry) {
"response_code": entry.ResponseCode, "response_code": entry.ResponseCode,
"source_ip": sourceIP, "source_ip": sourceIP,
"auth_mechanism": entry.AuthMechanism, "auth_mechanism": entry.AuthMechanism,
"user_id": entry.UserId, "user_id": entry.UserID,
"auth_success": entry.AuthSuccess, "auth_success": entry.AuthSuccess,
"error": err, "error": err,
}).Error("Error sending access log on gRPC connection") }).Error("Error sending access log on gRPC connection")
@@ -248,7 +328,7 @@ func (l *Logger) trackUsage(domain string, bytesTransferred int64) {
} }
} }
// cleanupStaleUsage removes usage entries for domains that have been inactive. // cleanupStaleUsage removes usage and deny-rate-limit entries that have been inactive.
func (l *Logger) cleanupStaleUsage(ctx context.Context) { func (l *Logger) cleanupStaleUsage(ctx context.Context) {
ticker := time.NewTicker(usageCleanupPeriod) ticker := time.NewTicker(usageCleanupPeriod)
defer ticker.Stop() defer ticker.Stop()
@@ -258,20 +338,41 @@ func (l *Logger) cleanupStaleUsage(ctx context.Context) {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
l.usageMux.Lock()
now := time.Now() now := time.Now()
removed := 0 l.cleanupDomainUsage(now)
for domain, usage := range l.domainUsage { l.cleanupDenyBuckets(now)
if now.Sub(usage.lastActivity) > usageInactiveWindow {
delete(l.domainUsage, domain)
removed++
}
}
l.usageMux.Unlock()
if removed > 0 {
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
}
} }
} }
} }
func (l *Logger) cleanupDomainUsage(now time.Time) {
l.usageMux.Lock()
defer l.usageMux.Unlock()
removed := 0
for domain, usage := range l.domainUsage {
if now.Sub(usage.lastActivity) > usageInactiveWindow {
delete(l.domainUsage, domain)
removed++
}
}
if removed > 0 {
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
}
}
func (l *Logger) cleanupDenyBuckets(now time.Time) {
l.denyMu.Lock()
defer l.denyMu.Unlock()
removed := 0
for key, bucket := range l.denyBuckets {
if now.Sub(bucket.lastLogged) > usageInactiveWindow {
delete(l.denyBuckets, key)
removed++
}
}
if removed > 0 {
l.logger.Debugf("cleaned up %d stale deny rate-limit entries", removed)
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/proxy/web"
) )
// Middleware wraps an HTTP handler to log access entries and resolve client IPs.
func (l *Logger) Middleware(next http.Handler) http.Handler { func (l *Logger) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip logging for internal proxy assets (CSS, JS, etc.) // Skip logging for internal proxy assets (CSS, JS, etc.)
@@ -47,8 +48,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
// Create a mutable struct to capture data from downstream handlers. // Create a mutable struct to capture data from downstream handlers.
// We pass a pointer in the context - the pointer itself flows down immutably, // We pass a pointer in the context - the pointer itself flows down immutably,
// but the struct it points to can be mutated by inner handlers. // but the struct it points to can be mutated by inner handlers.
capturedData := &proxy.CapturedData{RequestID: requestID} capturedData := proxy.NewCapturedData(requestID)
capturedData.SetClientIP(sourceIp) capturedData.SetClientIP(sourceIp)
ctx := proxy.WithCapturedData(r.Context(), capturedData) ctx := proxy.WithCapturedData(r.Context(), capturedData)
start := time.Now() start := time.Now()
@@ -66,8 +68,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
entry := logEntry{ entry := logEntry{
ID: requestID, ID: requestID,
ServiceId: capturedData.GetServiceId(), ServiceID: capturedData.GetServiceID(),
AccountID: capturedData.GetAccountId(), AccountID: capturedData.GetAccountID(),
Host: host, Host: host,
Path: r.URL.Path, Path: r.URL.Path,
DurationMs: duration.Milliseconds(), DurationMs: duration.Milliseconds(),
@@ -75,14 +77,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
ResponseCode: int32(sw.status), ResponseCode: int32(sw.status),
SourceIP: sourceIp, SourceIP: sourceIp,
AuthMechanism: capturedData.GetAuthMethod(), AuthMechanism: capturedData.GetAuthMethod(),
UserId: capturedData.GetUserID(), UserID: capturedData.GetUserID(),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
BytesUpload: bytesUpload, BytesUpload: bytesUpload,
BytesDownload: bytesDownload, BytesDownload: bytesDownload,
Protocol: ProtocolHTTP, Protocol: ProtocolHTTP,
} }
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID())
l.log(entry) l.log(entry)

View File

@@ -0,0 +1,69 @@
package auth
import (
"errors"
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ErrHeaderAuthFailed indicates that the header was present but the
// credential did not validate. Callers should return 401 instead of
// falling through to other auth schemes.
var ErrHeaderAuthFailed = errors.New("header authentication failed")
// Header implements header-based authentication. The proxy checks for the
// configured header in each request and validates its value via gRPC.
type Header struct {
id types.ServiceID
accountId types.AccountID
headerName string
client authenticator
}
// NewHeader creates a Header authentication scheme for the given header name.
func NewHeader(client authenticator, id types.ServiceID, accountId types.AccountID, headerName string) Header {
return Header{
id: id,
accountId: accountId,
headerName: headerName,
client: client,
}
}
// Type returns auth.MethodHeader.
func (Header) Type() auth.Method {
return auth.MethodHeader
}
// Authenticate checks for the configured header in the request. If absent,
// returns empty (unauthenticated). If present, validates via gRPC.
func (h Header) Authenticate(r *http.Request) (string, string, error) {
value := r.Header.Get(h.headerName)
if value == "" {
return "", "", nil
}
res, err := h.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
Id: string(h.id),
AccountId: string(h.accountId),
Request: &proto.AuthenticateRequest_HeaderAuth{
HeaderAuth: &proto.HeaderAuthRequest{
HeaderValue: value,
HeaderName: h.headerName,
},
},
})
if err != nil {
return "", "", fmt.Errorf("authenticate header: %w", err)
}
if res.GetSuccess() {
return res.GetSessionToken(), "", nil
}
return "", "", ErrHeaderAuthFailed
}

View File

@@ -4,9 +4,12 @@ import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"html"
"net" "net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@@ -16,11 +19,16 @@ import (
"github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/proxy/web"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
// errValidationUnavailable indicates that session validation failed due to
// an infrastructure error (e.g. gRPC unavailable), not an invalid token.
var errValidationUnavailable = errors.New("session validation unavailable")
type authenticator interface { type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
} }
@@ -40,12 +48,14 @@ type Scheme interface {
Authenticate(*http.Request) (token string, promptData string, err error) Authenticate(*http.Request) (token string, promptData string, err error)
} }
// DomainConfig holds the authentication and restriction settings for a protected domain.
type DomainConfig struct { type DomainConfig struct {
Schemes []Scheme Schemes []Scheme
SessionPublicKey ed25519.PublicKey SessionPublicKey ed25519.PublicKey
SessionExpiration time.Duration SessionExpiration time.Duration
AccountID types.AccountID AccountID types.AccountID
ServiceID types.ServiceID ServiceID types.ServiceID
IPRestrictions *restrict.Filter
} }
type validationResult struct { type validationResult struct {
@@ -54,17 +64,18 @@ type validationResult struct {
DeniedReason string DeniedReason string
} }
// Middleware applies per-domain authentication and IP restriction checks.
type Middleware struct { type Middleware struct {
domainsMux sync.RWMutex domainsMux sync.RWMutex
domains map[string]DomainConfig domains map[string]DomainConfig
logger *log.Logger logger *log.Logger
sessionValidator SessionValidator sessionValidator SessionValidator
geo restrict.GeoResolver
} }
// NewMiddleware creates a new authentication middleware. // NewMiddleware creates a new authentication middleware. The sessionValidator is
// The sessionValidator is optional; if nil, OIDC session tokens will be validated // optional; if nil, OIDC session tokens are validated locally without group access checks.
// locally without group access checks. func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware {
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
if logger == nil { if logger == nil {
logger = log.StandardLogger() logger = log.StandardLogger()
} }
@@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl
domains: make(map[string]DomainConfig), domains: make(map[string]DomainConfig),
logger: logger, logger: logger,
sessionValidator: sessionValidator, sessionValidator: sessionValidator,
geo: geo,
} }
} }
// Protect applies authentication middleware to the passed handler. // Protect wraps next with per-domain authentication and IP restriction checks.
// For each incoming request it will be checked against the middleware's // Requests whose Host is not registered pass through unchanged.
// internal list of protected domains.
// If the Host domain in the inbound request is not present, then it will
// simply be passed through.
// However, if the Host domain is present, then the specified authentication
// schemes for that domain will be applied to the request.
// In the event that no authentication schemes are defined for the domain,
// then the request will also be simply passed through.
func (mw *Middleware) Protect(next http.Handler) http.Handler { func (mw *Middleware) Protect(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host) host, _, err := net.SplitHostPort(r.Host)
@@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
config, exists := mw.getDomainConfig(host) config, exists := mw.getDomainConfig(host)
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists) mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
// Domains that are not configured here or have no authentication schemes applied should simply pass through. if !exists {
if !exists || len(config.Schemes) == 0 {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
@@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
// Set account and service IDs in captured data for access logging. // Set account and service IDs in captured data for access logging.
setCapturedIDs(r, config) setCapturedIDs(r, config)
if !mw.checkIPRestrictions(w, r, config) {
return
}
// Domains with no authentication schemes pass through after IP checks.
if len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
return
}
if mw.handleOAuthCallbackError(w, r) { if mw.handleOAuthCallbackError(w, r) {
return return
} }
@@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return return
} }
if mw.forwardWithHeaderAuth(w, r, host, config, next) {
return
}
mw.authenticateWithSchemes(w, r, host, config) mw.authenticateWithSchemes(w, r, host, config)
}) })
} }
@@ -124,11 +142,65 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
func setCapturedIDs(r *http.Request, config DomainConfig) { func setCapturedIDs(r *http.Request, config DomainConfig) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(config.AccountID) cd.SetAccountID(config.AccountID)
cd.SetServiceId(config.ServiceID) cd.SetServiceID(config.ServiceID)
} }
} }
// checkIPRestrictions validates the client IP against the domain's IP restrictions.
// Uses the resolved client IP from CapturedData (which accounts for trusted proxies)
// rather than r.RemoteAddr directly.
func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
if config.IPRestrictions == nil {
return true
}
clientIP := mw.resolveClientIP(r)
if !clientIP.IsValid() {
mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
if verdict == restrict.Allow {
return true
}
reason := verdict.String()
mw.blockIPRestriction(r, reason)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr.
func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
if ip := cd.GetClientIP(); ip.IsValid() {
return ip
}
}
clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr)
if clientIPStr == "" {
clientIPStr = r.RemoteAddr
}
addr, err := netip.ParseAddr(clientIPStr)
if err != nil {
return netip.Addr{}
}
return addr.Unmap()
}
// blockIPRestriction sets captured data fields for an IP-restriction block event.
func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(reason)
}
mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr)
}
// handleOAuthCallbackError checks for error query parameters from an OAuth // handleOAuthCallbackError checks for error query parameters from an OAuth
// callback and renders the access denied page if present. // callback and renders the access denied page if present.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool { func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
@@ -146,6 +218,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re
errDesc := r.URL.Query().Get("error_description") errDesc := r.URL.Query().Get("error_description")
if errDesc == "" { if errDesc == "" {
errDesc = "An error occurred during authentication" errDesc = "An error occurred during authentication"
} else {
errDesc = html.EscapeString(errDesc)
} }
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID) web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
return true return true
@@ -170,6 +244,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
return true return true
} }
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
// the request is forwarded directly (no redirect), which is important for API clients.
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
for _, scheme := range config.Schemes {
hdr, ok := scheme.(Header)
if !ok {
continue
}
handled := mw.tryHeaderScheme(w, r, host, config, hdr, next)
if handled {
return true
}
}
return false
}
func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool {
token, _, err := hdr.Authenticate(r)
if err != nil {
return mw.handleHeaderAuthError(w, r, err)
}
if token == "" {
return false
}
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
if err != nil {
setHeaderCapturedData(r.Context(), "")
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return true
}
if !result.Valid {
setHeaderCapturedData(r.Context(), result.UserID)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
setSessionCookie(w, token, config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(result.UserID)
cd.SetAuthMethod(auth.MethodHeader.String())
}
next.ServeHTTP(w, r)
return true
}
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
if errors.Is(err, ErrHeaderAuthFailed) {
setHeaderCapturedData(r.Context(), "")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return true
}
func setHeaderCapturedData(ctx context.Context, userID string) {
cd := proxy.CapturedDataFromContext(ctx)
if cd == nil {
return
}
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodHeader.String())
cd.SetUserID(userID)
}
// authenticateWithSchemes tries each configured auth scheme in order. // authenticateWithSchemes tries each configured auth scheme in order.
// On success it sets a session cookie and redirects; on failure it renders the login page. // On success it sets a session cookie and redirects; on failure it renders the login page.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) { func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
@@ -217,7 +370,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
cd.SetOrigin(proxy.OriginAuth) cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(scheme.Type().String()) cd.SetAuthMethod(scheme.Type().String())
} }
http.Error(w, err.Error(), http.StatusBadRequest) status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return return
} }
@@ -233,7 +392,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
return return
} }
expiration := config.SessionExpiration setSessionCookie(w, token, config.SessionExpiration)
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// setSessionCookie writes a session cookie with secure defaults.
func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) {
if expiration == 0 { if expiration == 0 {
expiration = auth.DefaultSessionExpiry expiration = auth.DefaultSessionExpiry
} }
@@ -245,16 +418,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()), MaxAge: int(expiration.Seconds()),
}) })
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
// wasCredentialSubmitted checks if credentials were submitted for the given auth method. // wasCredentialSubmitted checks if credentials were submitted for the given auth method.
@@ -275,13 +438,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
// session JWTs. Returns an error if the key is missing or invalid. // session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid // Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service. // exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error { func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
if len(schemes) == 0 { if len(schemes) == 0 {
mw.domainsMux.Lock() mw.domainsMux.Lock()
defer mw.domainsMux.Unlock() defer mw.domainsMux.Unlock()
mw.domains[domain] = DomainConfig{ mw.domains[domain] = DomainConfig{
AccountID: accountID, AccountID: accountID,
ServiceID: serviceID, ServiceID: serviceID,
IPRestrictions: ipRestrictions,
} }
return nil return nil
} }
@@ -302,30 +466,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
SessionExpiration: expiration, SessionExpiration: expiration,
AccountID: accountID, AccountID: accountID,
ServiceID: serviceID, ServiceID: serviceID,
IPRestrictions: ipRestrictions,
} }
return nil return nil
} }
// RemoveDomain unregisters authentication for the given domain.
func (mw *Middleware) RemoveDomain(domain string) { func (mw *Middleware) RemoveDomain(domain string) {
mw.domainsMux.Lock() mw.domainsMux.Lock()
defer mw.domainsMux.Unlock() defer mw.domainsMux.Unlock()
delete(mw.domains, domain) delete(mw.domains, domain)
} }
// validateSessionToken validates a session token, optionally checking group access via gRPC. // validateSessionToken validates a session token. OIDC tokens with a configured
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access. // validator go through gRPC for group access checks; other methods validate locally.
// For other auth methods (PIN, password), it validates the JWT locally.
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) { func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
// For OIDC with a session validator, call the gRPC service to check group access
if method == auth.MethodOIDC && mw.sessionValidator != nil { if method == auth.MethodOIDC && mw.sessionValidator != nil {
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{ resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
Domain: host, Domain: host,
SessionToken: token, SessionToken: token,
}) })
if err != nil { if err != nil {
mw.logger.WithError(err).Error("ValidateSession gRPC call failed") return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err)
return nil, fmt.Errorf("session validation failed")
} }
if !resp.Valid { if !resp.Valid {
mw.logger.WithFields(log.Fields{ mw.logger.WithFields(log.Fields{
@@ -342,7 +504,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
return &validationResult{UserID: resp.UserId, Valid: true}, nil return &validationResult{UserID: resp.UserId, Valid: true}, nil
} }
// For non-OIDC methods or when no validator is configured, validate JWT locally
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey) userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -1,11 +1,14 @@
package auth package auth
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
@@ -14,10 +17,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/shared/management/proto"
) )
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair { func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
@@ -52,11 +58,11 @@ func newPassthroughHandler() http.Handler {
} }
func TestAddDomain_ValidKey(t *testing.T) { func TestAddDomain_ValidKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "") err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
require.NoError(t, err) require.NoError(t, err)
mw.domainsMux.RLock() mw.domainsMux.RLock()
@@ -70,10 +76,10 @@ func TestAddDomain_ValidKey(t *testing.T) {
} }
func TestAddDomain_EmptyKey(t *testing.T) { func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "") err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size") assert.Contains(t, err.Error(), "invalid session public key size")
@@ -84,10 +90,10 @@ func TestAddDomain_EmptyKey(t *testing.T) {
} }
func TestAddDomain_InvalidBase64(t *testing.T) { func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "") err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key") assert.Contains(t, err.Error(), "decode session public key")
@@ -98,11 +104,11 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
} }
func TestAddDomain_WrongKeySize(t *testing.T) { func TestAddDomain_WrongKeySize(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort")) shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "") err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size") assert.Contains(t, err.Error(), "invalid session public key size")
@@ -113,9 +119,9 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
} }
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "") err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
require.NoError(t, err, "domains with no auth schemes should not require a key") require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock() mw.domainsMux.RLock()
@@ -125,14 +131,14 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
} }
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp1 := generateTestKeyPair(t) kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
mw.domainsMux.RLock() mw.domainsMux.RLock()
config := mw.domains["example.com"] config := mw.domains["example.com"]
@@ -144,11 +150,11 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
} }
func TestRemoveDomain(t *testing.T) { func TestRemoveDomain(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
mw.RemoveDomain("example.com") mw.RemoveDomain("example.com")
@@ -159,7 +165,7 @@ func TestRemoveDomain(t *testing.T) {
} }
func TestProtect_UnknownDomainPassesThrough(t *testing.T) { func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
handler := mw.Protect(newPassthroughHandler()) handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil) req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
@@ -171,8 +177,8 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
} }
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler()) handler := mw.Protect(newPassthroughHandler())
@@ -185,11 +191,11 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
} }
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -206,11 +212,11 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
} }
func TestProtect_HostWithPortIsMatched(t *testing.T) { func TestProtect_HostWithPortIsMatched(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -227,16 +233,16 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
} }
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err) require.NoError(t, err)
capturedData := &proxy.CapturedData{} capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cd := proxy.CapturedDataFromContext(r.Context()) cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd) require.NotNil(t, cd)
@@ -257,11 +263,11 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
} }
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Sign a token that expired 1 second ago. // Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second) token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
@@ -283,11 +289,11 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
} }
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Token signed for a different domain audience. // Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour) token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
@@ -309,12 +315,12 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
} }
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp1 := generateTestKeyPair(t) kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t) kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
// Token signed with a different private key. // Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
@@ -336,7 +342,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
} }
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour) token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
@@ -351,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
return "", "pin", nil return "", "pin", nil
}, },
} }
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -386,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
} }
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{ scheme := &stubScheme{
@@ -395,7 +401,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
return "", "pin", nil return "", "pin", nil
}, },
} }
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler()) handler := mw.Protect(newPassthroughHandler())
@@ -409,7 +415,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
} }
func TestProtect_MultipleSchemes(t *testing.T) { func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour) token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
@@ -431,7 +437,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
return "", "password", nil return "", "password", nil
}, },
} }
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -451,7 +457,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
} }
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
// Return a garbage token that won't validate. // Return a garbage token that won't validate.
@@ -461,7 +467,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
return "invalid-jwt-token", "", nil return "invalid-jwt-token", "", nil
}, },
} }
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler()) handler := mw.Protect(newPassthroughHandler())
@@ -473,7 +479,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
} }
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
// 32 random bytes that happen to be valid base64 and correct size // 32 random bytes that happen to be valid base64 and correct size
// but are actually a valid ed25519 public key length-wise. // but are actually a valid ed25519 public key length-wise.
@@ -485,19 +491,19 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
key := base64.StdEncoding.EncodeToString(randomBytes) key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "") err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
require.NoError(t, err, "any 32-byte key should be accepted at registration time") require.NoError(t, err, "any 32-byte key should be accepted at registration time")
} }
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Attempt to overwrite with an invalid key. // Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "") err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
require.Error(t, err) require.Error(t, err)
// The original valid config should still be intact. // The original valid config should still be intact.
@@ -511,7 +517,7 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
} }
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
// Scheme that always fails authentication (returns empty token) // Scheme that always fails authentication (returns empty token)
@@ -521,9 +527,9 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
return "", "pin", nil return "", "pin", nil
}, },
} }
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{} capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler()) handler := mw.Protect(newPassthroughHandler())
// Submit wrong PIN - should capture auth method // Submit wrong PIN - should capture auth method
@@ -539,7 +545,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
} }
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{ scheme := &stubScheme{
@@ -548,9 +554,9 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
return "", "password", nil return "", "password", nil
}, },
} }
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{} capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler()) handler := mw.Protect(newPassthroughHandler())
// Submit wrong password - should capture auth method // Submit wrong password - should capture auth method
@@ -566,7 +572,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
} }
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) { func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil) mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t) kp := generateTestKeyPair(t)
scheme := &stubScheme{ scheme := &stubScheme{
@@ -575,9 +581,9 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
return "", "pin", nil return "", "pin", nil
}, },
} }
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{} capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler()) handler := mw.Protect(newPassthroughHandler())
// No credentials submitted - should not capture auth method // No credentials submitted - should not capture auth method
@@ -658,3 +664,271 @@ func TestWasCredentialSubmitted(t *testing.T) {
}) })
} }
} }
func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
remoteAddr string
wantCode int
}{
{"unparsable address denies", "not-an-ip:1234", http.StatusForbidden},
{"empty address denies", "", http.StatusForbidden},
{"allowed address passes", "10.1.2.3:5678", http.StatusOK},
{"denied address blocked", "192.168.1.1:5678", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = tt.remoteAddr
req.Host = "example.com"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, tt.wantCode, rr.Code)
})
}
}
func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
// When CapturedData is set (by the access log middleware, which resolves
// trusted proxies), checkIPRestrictions should use that IP, not RemoteAddr.
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// RemoteAddr is a trusted proxy, but CapturedData has the real client IP.
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "10.0.0.1:5000"
req.Host = "example.com"
cd := proxy.NewCapturedData("")
cd.SetClientIP(netip.MustParseAddr("203.0.113.50"))
ctx := proxy.WithCapturedData(req.Context(), cd)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "should use CapturedData IP (203.0.113.50), not RemoteAddr (10.0.0.1)")
// Same request but CapturedData has a blocked IP.
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req2.RemoteAddr = "203.0.113.50:5000"
req2.Host = "example.com"
cd2 := proxy.NewCapturedData("")
cd2.SetClientIP(netip.MustParseAddr("10.0.0.1"))
ctx2 := proxy.WithCapturedData(req2.Context(), cd2)
req2 = req2.WithContext(ctx2)
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusForbidden, rr2.Code, "should use CapturedData IP (10.0.0.1), not RemoteAddr (203.0.113.50)")
}
func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
// Geo is nil, country restrictions are configured: must deny (fail-close).
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(nil, nil, []string{"US"}, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "1.2.3.4:5678"
req.Host = "example.com"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
}
// mockAuthenticator is a minimal mock for the authenticator gRPC interface
// used by the Header scheme.
type mockAuthenticator struct {
fn func(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error)
}
func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, _ ...grpc.CallOption) (*proto.AuthenticateResponse, error) {
return m.fn(ctx, in)
}
// newHeaderSchemeWithToken creates a Header scheme backed by a mock that
// returns a signed session token when the expected header value is provided.
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
t.Helper()
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
require.NoError(t, err)
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && ha.GetHeaderValue() == expectedValue {
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
return &proto.AuthenticateResponse{Success: false}, nil
}}
return NewHeader(mock, "svc1", "acc1", headerName)
}
func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
var backendCalled bool
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)
req.Header.Set("X-API-Key", "secret-key")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.True(t, backendCalled, "backend should be called directly for header auth (no redirect)")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "ok", rec.Body.String())
// Session cookie should be set.
var sessionCookie *http.Cookie
for _, c := range rec.Result().Cookies() {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
require.NotNil(t, sessionCookie, "session cookie should be set after successful header auth")
assert.True(t, sessionCookie.HttpOnly)
assert.True(t, sessionCookie.Secure)
assert.Equal(t, "header-user", capturedData.GetUserID())
assert.Equal(t, "header", capturedData.GetAuthMethod())
}
func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
// Also add a PIN scheme so we can verify fallthrough behavior.
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
// No X-API-Key header: should fall through to PIN login page (401).
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code, "missing header should fall through to login page")
}
func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
return &proto.AuthenticateResponse{Success: false}, nil
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("X-API-Key", "wrong-key")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Equal(t, "header", capturedData.GetAuthMethod())
}
func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
return nil, errors.New("gRPC unavailable")
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("X-API-Key", "some-key")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadGateway, rec.Code)
}
func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request with header auth.
req1 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req1.Header.Set("X-API-Key", "secret-key")
req1 = req1.WithContext(proxy.WithCapturedData(req1.Context(), proxy.NewCapturedData("")))
rec1 := httptest.NewRecorder()
handler.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
// Extract session cookie.
var sessionCookie *http.Cookie
for _, c := range rec1.Result().Cookies() {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
require.NotNil(t, sessionCookie)
// Second request with only the session cookie (no header).
capturedData2 := proxy.NewCapturedData("")
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/other", nil)
req2.AddCookie(sessionCookie)
req2 = req2.WithContext(proxy.WithCapturedData(req2.Context(), capturedData2))
rec2 := httptest.NewRecorder()
handler.ServeHTTP(rec2, req2)
assert.Equal(t, http.StatusOK, rec2.Code)
assert.Equal(t, "header-user", capturedData2.GetUserID())
assert.Equal(t, "header", capturedData2.GetAuthMethod())
}

View File

@@ -0,0 +1,264 @@
package geolocation
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
mmdbInnerName = "GeoLite2-City.mmdb"
downloadTimeout = 2 * time.Minute
maxMMDBSize = 256 << 20 // 256 MB
)
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
// it downloads from pkgs.netbird.io with SHA256 verification.
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
if err := os.MkdirAll(dataDir, 0o755); err != nil {
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
}
pattern := filepath.Join(dataDir, mmdbGlob)
if files, _ := filepath.Glob(pattern); len(files) > 0 {
mmdbPath := files[len(files)-1]
logger.Debugf("using existing geolocation database: %s", mmdbPath)
return mmdbPath, nil
}
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
return downloadMMDB(logger, dataDir)
}
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
client := &http.Client{Timeout: downloadTimeout}
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
if err != nil {
return "", fmt.Errorf("get remote filename: %w", err)
}
mmdbFilename := deriveMMDBFilename(datedName)
mmdbPath := filepath.Join(dataDir, mmdbFilename)
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
if err != nil {
return "", fmt.Errorf("create temp directory: %w", err)
}
defer os.RemoveAll(tmp)
checksumFile := filepath.Join(tmp, "checksum.sha256")
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
return "", fmt.Errorf("download checksum: %w", err)
}
expectedHash, err := readChecksumFile(checksumFile)
if err != nil {
return "", fmt.Errorf("read checksum: %w", err)
}
tarFile := filepath.Join(tmp, datedName)
logger.Debugf("downloading geolocation database (%s)", datedName)
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
return "", fmt.Errorf("download database: %w", err)
}
if err := verifySHA256(tarFile, expectedHash); err != nil {
return "", fmt.Errorf("verify database checksum: %w", err)
}
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
return "", fmt.Errorf("extract database: %w", err)
}
logger.Infof("geolocation database downloaded: %s", mmdbPath)
return mmdbPath, nil
}
// deriveMMDBFilename converts a tar.gz filename to an MMDB filename.
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
func deriveMMDBFilename(tarName string) string {
base, _, _ := strings.Cut(tarName, ".")
if !strings.Contains(base, "_") {
return "GeoLite2-City.mmdb"
}
return base + ".mmdb"
}
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
resp, err := client.Head(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
}
cd := resp.Header.Get("Content-Disposition")
if cd == "" {
return "", errors.New("no Content-Disposition header")
}
_, params, err := mime.ParseMediaType(cd)
if err != nil {
return "", fmt.Errorf("parse Content-Disposition: %w", err)
}
name := filepath.Base(params["filename"])
if name == "" || name == "." {
return "", errors.New("no filename in Content-Disposition")
}
return name, nil
}
func downloadToFile(client *http.Client, url, dest string) error {
resp, err := client.Get(url) //nolint:gosec
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
f, err := os.Create(dest) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
return err
}
return nil
}
func readChecksumFile(path string) (string, error) {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
if scanner.Scan() {
parts := strings.Fields(scanner.Text())
if len(parts) > 0 {
return parts[0], nil
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", errors.New("empty checksum file")
}
func verifySHA256(path, expected string) error {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actual := fmt.Sprintf("%x", h.Sum(nil))
if actual != expected {
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
}
return nil
}
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
f, err := os.Open(tarGZPath) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gz.Close()
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
}
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
return err
}
return nil
}
}
return fmt.Errorf("%s not found in archive", mmdbInnerName)
}
// extractToFileAtomic writes r to a temporary file in the same directory as
// destPath, then renames it into place so a crash never leaves a truncated file.
func extractToFileAtomic(r io.Reader, destPath string) error {
dir := filepath.Dir(destPath)
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmp.Name()
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
if closeErr := tmp.Close(); closeErr != nil {
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
}
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("write mmdb: %w", err)
}
if err := tmp.Close(); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("close temp file: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("rename to %s: %w", destPath, err)
}
return nil
}

View File

@@ -0,0 +1,152 @@
// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases.
package geolocation
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync"
"github.com/oschwald/maxminddb-golang"
log "github.com/sirupsen/logrus"
)
const (
// EnvDisable disables geolocation lookups entirely when set to a truthy value.
EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION"
mmdbGlob = "GeoLite2-City_*.mmdb"
)
type record struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
City struct {
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"city"`
Subdivisions []struct {
ISOCode string `maxminddb:"iso_code"`
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"subdivisions"`
}
// Result holds the outcome of a geo lookup.
type Result struct {
CountryCode string
CityName string
SubdivisionCode string
SubdivisionName string
}
// Lookup provides IP geolocation lookups.
type Lookup struct {
mu sync.RWMutex
db *maxminddb.Reader
logger *log.Logger
}
// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir.
// Returns nil without error if geolocation is disabled via environment
// variable, no data directory is configured, or the download fails
// (graceful degradation: country restrictions will deny all requests).
func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) {
if isDisabledByEnv(logger) {
logger.Info("geolocation disabled via environment variable")
return nil, nil //nolint:nilnil
}
if dataDir == "" {
return nil, nil //nolint:nilnil
}
mmdbPath, err := ensureMMDB(logger, dataDir)
if err != nil {
logger.Warnf("geolocation database unavailable: %v", err)
logger.Warn("country-based access restrictions will deny all requests until a database is available")
return nil, nil //nolint:nilnil
}
db, err := maxminddb.Open(mmdbPath)
if err != nil {
return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err)
}
logger.Infof("geolocation database loaded from %s", mmdbPath)
return &Lookup{db: db, logger: logger}, nil
}
// LookupAddr returns the country ISO code and city name for the given IP.
// Returns an empty Result if the database is nil or the lookup fails.
func (l *Lookup) LookupAddr(addr netip.Addr) Result {
if l == nil {
return Result{}
}
l.mu.RLock()
defer l.mu.RUnlock()
if l.db == nil {
return Result{}
}
addr = addr.Unmap()
var rec record
if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil {
l.logger.Debugf("geolocation lookup %s: %v", addr, err)
return Result{}
}
r := Result{
CountryCode: rec.Country.ISOCode,
CityName: rec.City.Names.En,
}
if len(rec.Subdivisions) > 0 {
r.SubdivisionCode = rec.Subdivisions[0].ISOCode
r.SubdivisionName = rec.Subdivisions[0].Names.En
}
return r
}
// Available reports whether the lookup has a loaded database.
func (l *Lookup) Available() bool {
if l == nil {
return false
}
l.mu.RLock()
defer l.mu.RUnlock()
return l.db != nil
}
// Close releases the database resources.
func (l *Lookup) Close() error {
if l == nil {
return nil
}
l.mu.Lock()
defer l.mu.Unlock()
if l.db != nil {
err := l.db.Close()
l.db = nil
return err
}
return nil
}
func isDisabledByEnv(logger *log.Logger) bool {
val := os.Getenv(EnvDisable)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
logger.Warnf("parse %s=%q: %v", EnvDisable, val, err)
return false
}
return disabled
}

View File

@@ -11,8 +11,6 @@ import (
type requestContextKey string type requestContextKey string
const ( const (
serviceIdKey requestContextKey = "serviceId"
accountIdKey requestContextKey = "accountId"
capturedDataKey requestContextKey = "capturedData" capturedDataKey requestContextKey = "capturedData"
) )
@@ -47,112 +45,117 @@ func (o ResponseOrigin) String() string {
// to pass data back up the middleware chain. // to pass data back up the middleware chain.
type CapturedData struct { type CapturedData struct {
mu sync.RWMutex mu sync.RWMutex
RequestID string requestID string
ServiceId types.ServiceID serviceID types.ServiceID
AccountId types.AccountID accountID types.AccountID
Origin ResponseOrigin origin ResponseOrigin
ClientIP netip.Addr clientIP netip.Addr
UserID string userID string
AuthMethod string authMethod string
} }
// GetRequestID safely gets the request ID // NewCapturedData creates a CapturedData with the given request ID.
func NewCapturedData(requestID string) *CapturedData {
return &CapturedData{requestID: requestID}
}
// GetRequestID returns the request ID.
func (c *CapturedData) GetRequestID() string { func (c *CapturedData) GetRequestID() string {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.RequestID return c.requestID
} }
// SetServiceId safely sets the service ID // SetServiceID sets the service ID.
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) { func (c *CapturedData) SetServiceID(serviceID types.ServiceID) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.ServiceId = serviceId c.serviceID = serviceID
} }
// GetServiceId safely gets the service ID // GetServiceID returns the service ID.
func (c *CapturedData) GetServiceId() types.ServiceID { func (c *CapturedData) GetServiceID() types.ServiceID {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.ServiceId return c.serviceID
} }
// SetAccountId safely sets the account ID // SetAccountID sets the account ID.
func (c *CapturedData) SetAccountId(accountId types.AccountID) { func (c *CapturedData) SetAccountID(accountID types.AccountID) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.AccountId = accountId c.accountID = accountID
} }
// GetAccountId safely gets the account ID // GetAccountID returns the account ID.
func (c *CapturedData) GetAccountId() types.AccountID { func (c *CapturedData) GetAccountID() types.AccountID {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.AccountId return c.accountID
} }
// SetOrigin safely sets the response origin // SetOrigin sets the response origin.
func (c *CapturedData) SetOrigin(origin ResponseOrigin) { func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.Origin = origin c.origin = origin
} }
// GetOrigin safely gets the response origin // GetOrigin returns the response origin.
func (c *CapturedData) GetOrigin() ResponseOrigin { func (c *CapturedData) GetOrigin() ResponseOrigin {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.Origin return c.origin
} }
// SetClientIP safely sets the resolved client IP. // SetClientIP sets the resolved client IP.
func (c *CapturedData) SetClientIP(ip netip.Addr) { func (c *CapturedData) SetClientIP(ip netip.Addr) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.ClientIP = ip c.clientIP = ip
} }
// GetClientIP safely gets the resolved client IP. // GetClientIP returns the resolved client IP.
func (c *CapturedData) GetClientIP() netip.Addr { func (c *CapturedData) GetClientIP() netip.Addr {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.ClientIP return c.clientIP
} }
// SetUserID safely sets the authenticated user ID. // SetUserID sets the authenticated user ID.
func (c *CapturedData) SetUserID(userID string) { func (c *CapturedData) SetUserID(userID string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.UserID = userID c.userID = userID
} }
// GetUserID safely gets the authenticated user ID. // GetUserID returns the authenticated user ID.
func (c *CapturedData) GetUserID() string { func (c *CapturedData) GetUserID() string {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.UserID return c.userID
} }
// SetAuthMethod safely sets the authentication method used. // SetAuthMethod sets the authentication method used.
func (c *CapturedData) SetAuthMethod(method string) { func (c *CapturedData) SetAuthMethod(method string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.AuthMethod = method c.authMethod = method
} }
// GetAuthMethod safely gets the authentication method used. // GetAuthMethod returns the authentication method used.
func (c *CapturedData) GetAuthMethod() string { func (c *CapturedData) GetAuthMethod() string {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.AuthMethod return c.authMethod
} }
// WithCapturedData adds a CapturedData struct to the context // WithCapturedData adds a CapturedData struct to the context.
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
return context.WithValue(ctx, capturedDataKey, data) return context.WithValue(ctx, capturedDataKey, data)
} }
// CapturedDataFromContext retrieves the CapturedData from context // CapturedDataFromContext retrieves the CapturedData from context.
func CapturedDataFromContext(ctx context.Context) *CapturedData { func CapturedDataFromContext(ctx context.Context) *CapturedData {
v := ctx.Value(capturedDataKey) v := ctx.Value(capturedDataKey)
data, ok := v.(*CapturedData) data, ok := v.(*CapturedData)
@@ -161,28 +164,3 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
} }
return data return data
} }
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
return context.WithValue(ctx, serviceIdKey, serviceId)
}
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
v := ctx.Value(serviceIdKey)
serviceId, ok := v.(types.ServiceID)
if !ok {
return ""
}
return serviceId
}
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
return context.WithValue(ctx, accountIdKey, accountId)
}
func AccountIdFromContext(ctx context.Context) types.AccountID {
v := ctx.Value(accountIdKey)
accountId, ok := v.(types.AccountID)
if !ok {
return ""
}
return accountId
}

View File

@@ -66,19 +66,16 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// Set the serviceId in the context for later retrieval. ctx := r.Context()
ctx := withServiceId(r.Context(), result.serviceID) // Set the account ID in the context for the roundtripper to use.
// Set the accountId in the context for later retrieval (for middleware).
ctx = withAccountId(ctx, result.accountID)
// Set the accountId in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, result.accountID) ctx = roundtrip.WithAccountID(ctx, result.accountID)
// Also populate captured data if it exists (allows middleware to read after handler completes). // Populate captured data if it exists (allows middleware to read after handler completes).
// This solves the problem of passing data UP the middleware chain: we put a mutable struct // This solves the problem of passing data UP the middleware chain: we put a mutable struct
// pointer in the context, and mutate the struct here so outer middleware can read it. // pointer in the context, and mutate the struct here so outer middleware can read it.
if capturedData := CapturedDataFromContext(ctx); capturedData != nil { if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
capturedData.SetServiceId(result.serviceID) capturedData.SetServiceID(result.serviceID)
capturedData.SetAccountId(result.accountID) capturedData.SetAccountID(result.accountID)
} }
pt := result.target pt := result.target
@@ -96,10 +93,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
rp := &httputil.ReverseProxy{ rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders), Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders),
Transport: p.transport, Transport: p.transport,
FlushInterval: -1, FlushInterval: -1,
ErrorHandler: proxyErrorHandler, ErrorHandler: p.proxyErrorHandler,
} }
if result.rewriteRedirects { if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
@@ -113,7 +110,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// When passHostHeader is true, the original client Host header is preserved // When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address. // instead of being rewritten to the backend's address.
// The pathRewrite parameter controls how the request path is transformed. // The pathRewrite parameter controls how the request path is transformed.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) { func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string, stripAuthHeaders []string) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) { return func(r *httputil.ProxyRequest) {
switch pathRewrite { switch pathRewrite {
case PathRewritePreserve: case PathRewritePreserve:
@@ -137,6 +134,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Host = target.Host r.Out.Host = target.Host
} }
for _, h := range stripAuthHeaders {
r.Out.Header.Del(h)
}
for k, v := range customHeaders { for k, v := range customHeaders {
r.Out.Header.Set(k, v) r.Out.Header.Set(k, v)
} }
@@ -305,7 +306,7 @@ func extractForwardedPort(host, resolvedProto string) string {
// proxyErrorHandler handles errors from the reverse proxy and serves // proxyErrorHandler handles errors from the reverse proxy and serves
// user-friendly error pages instead of raw error responses. // user-friendly error pages instead of raw error responses.
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { func (p *ReverseProxy) proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
if cd := CapturedDataFromContext(r.Context()); cd != nil { if cd := CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(OriginProxyError) cd.SetOrigin(OriginProxyError)
} }
@@ -313,7 +314,7 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
clientIP := getClientIP(r) clientIP := getClientIP(r)
title, message, code, status := classifyProxyError(err) title, message, code, status := classifyProxyError(err)
log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v", p.logger.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err) requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
web.ServeErrorPage(w, r, code, title, message, requestID, status) web.ServeErrorPage(w, r, code, title, message, requestID, status)

View File

@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) { t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr) rewrite(pr)
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
}) })
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) { t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr) rewrite(pr)
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
func TestRewriteFunc_XForwardedForStripping(t *testing.T) { func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) { t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) { t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000") pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) { t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) { t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{} pr.In.TLS = &tls.ConnectionState{}
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) { t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects https from TLS", func(t *testing.T) { t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{} pr.In.TLS = &tls.ConnectionState{}
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects http without TLS", func(t *testing.T) { t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced proto overrides TLS detection", func(t *testing.T) { t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"} p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https // No TLS, but forced to https
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced http proto", func(t *testing.T) { t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"} p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{} pr.In.TLS = &tls.ConnectionState{}
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
func TestRewriteFunc_SessionCookieStripping(t *testing.T) { func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("strips nb_session cookie", func(t *testing.T) { t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) { func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("strips session_token query parameter", func(t *testing.T) { t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("rewrites URL to target with path prefix", func(t *testing.T) { t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app") target, _ := url.Parse("http://backend.internal:8080/app")
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) { t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app") target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) { t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app") target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("appends to X-Forwarded-For", func(t *testing.T) { t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Real-IP", func(t *testing.T) { t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) { t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) { t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com") pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) { t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https") pr.In.Header.Set("X-Forwarded-Proto", "https")
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) { t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443") pr.In.Header.Set("X-Forwarded-Port", "8443")
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) { t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) { t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) { t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1") pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) { t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) { t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) { t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise" // Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise") target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("subpath under prefix also preserved", func(t *testing.T) { t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise") target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// What the behavior WOULD be if target URL had no path (true stripping) // What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) { t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443") target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) { t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443") target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// Root path "/" — no stripping expected // Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) { t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/") target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -551,7 +551,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
t.Run("preserve keeps full request path", func(t *testing.T) { t.Run("preserve keeps full request path", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil) rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil, nil)
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -561,7 +561,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
}) })
t.Run("preserve with root matchedPath", func(t *testing.T) { t.Run("preserve with root matchedPath", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil) rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil, nil)
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -579,7 +579,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
"X-Custom-Auth": "token-abc", "X-Custom-Auth": "token-abc",
"X-Env": "production", "X-Env": "production",
} }
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers) rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -589,7 +589,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
}) })
t.Run("nil customHeaders is fine", func(t *testing.T) { t.Run("nil customHeaders is fine", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil) rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -599,7 +599,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
t.Run("custom headers override existing request headers", func(t *testing.T) { t.Run("custom headers override existing request headers", func(t *testing.T) {
headers := map[string]string{"X-Override": "new-value"} headers := map[string]string{"X-Override": "new-value"}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers) rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("X-Override", "old-value") pr.In.Header.Set("X-Override", "old-value")
@@ -609,11 +609,38 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
}) })
} }
func TestRewriteFunc_StripsAuthorizationHeader(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
t.Run("strips incoming Authorization when no custom Authorization set", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, []string{"Authorization"})
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("Authorization", "Bearer proxy-token")
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get("Authorization"), "Authorization should be stripped")
})
t.Run("custom Authorization replaces incoming", func(t *testing.T) {
headers := map[string]string{"Authorization": "Basic YmFja2VuZDpzZWNyZXQ="}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, []string{"Authorization"})
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("Authorization", "Bearer proxy-token")
rewrite(pr)
assert.Equal(t, "Basic YmFja2VuZDpzZWNyZXQ=", pr.Out.Header.Get("Authorization"),
"backend Authorization from custom headers should be set")
})
}
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) { func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}) rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}, nil)
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)

View File

@@ -38,6 +38,11 @@ type Mapping struct {
Paths map[string]*PathTarget Paths map[string]*PathTarget
PassHostHeader bool PassHostHeader bool
RewriteRedirects bool RewriteRedirects bool
// StripAuthHeaders are header names used for header-based auth.
// These headers are stripped from requests before forwarding.
StripAuthHeaders []string
// sortedPaths caches the paths sorted by length (longest first).
sortedPaths []string
} }
type targetResult struct { type targetResult struct {
@@ -47,6 +52,7 @@ type targetResult struct {
accountID types.AccountID accountID types.AccountID
passHostHeader bool passHostHeader bool
rewriteRedirects bool rewriteRedirects bool
stripAuthHeaders []string
} }
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) { func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
@@ -65,16 +71,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
return targetResult{}, false return targetResult{}, false
} }
// Sort paths by length (longest first) in a naive attempt to match the most specific route first. for _, path := range m.sortedPaths {
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
for _, path := range paths {
if strings.HasPrefix(req.URL.Path, path) { if strings.HasPrefix(req.URL.Path, path) {
pt := m.Paths[path] pt := m.Paths[path]
if pt == nil || pt.URL == nil { if pt == nil || pt.URL == nil {
@@ -89,6 +86,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
accountID: m.AccountID, accountID: m.AccountID,
passHostHeader: m.PassHostHeader, passHostHeader: m.PassHostHeader,
rewriteRedirects: m.RewriteRedirects, rewriteRedirects: m.RewriteRedirects,
stripAuthHeaders: m.StripAuthHeaders,
}, true }, true
} }
} }
@@ -96,7 +94,18 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
return targetResult{}, false return targetResult{}, false
} }
// AddMapping registers a host-to-backend mapping for the reverse proxy.
func (p *ReverseProxy) AddMapping(m Mapping) { func (p *ReverseProxy) AddMapping(m Mapping) {
// Sort paths longest-first to match the most specific route first.
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
m.sortedPaths = paths
p.mappingsMux.Lock() p.mappingsMux.Lock()
defer p.mappingsMux.Unlock() defer p.mappingsMux.Unlock()
p.mappings[m.Host] = m p.mappings[m.Host] = m

View File

@@ -0,0 +1,183 @@
// Package restrict provides connection-level access control based on
// IP CIDR ranges and geolocation (country codes).
package restrict
import (
"net/netip"
"slices"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
)
// GeoResolver resolves an IP address to geographic information.
type GeoResolver interface {
LookupAddr(addr netip.Addr) geolocation.Result
Available() bool
}
// Filter evaluates IP restrictions. CIDR checks are performed first
// (cheap), followed by country lookups (more expensive) only when needed.
type Filter struct {
AllowedCIDRs []netip.Prefix
BlockedCIDRs []netip.Prefix
AllowedCountries []string
BlockedCountries []string
}
// ParseFilter builds a Filter from the raw string slices. Returns nil
// if all slices are empty.
func ParseFilter(allowedCIDRs, blockedCIDRs, allowedCountries, blockedCountries []string) *Filter {
if len(allowedCIDRs) == 0 && len(blockedCIDRs) == 0 &&
len(allowedCountries) == 0 && len(blockedCountries) == 0 {
return nil
}
f := &Filter{
AllowedCountries: normalizeCountryCodes(allowedCountries),
BlockedCountries: normalizeCountryCodes(blockedCountries),
}
for _, cidr := range allowedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
log.Warnf("skip invalid allowed CIDR %q: %v", cidr, err)
continue
}
f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked())
}
for _, cidr := range blockedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
log.Warnf("skip invalid blocked CIDR %q: %v", cidr, err)
continue
}
f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked())
}
return f
}
func normalizeCountryCodes(codes []string) []string {
if len(codes) == 0 {
return nil
}
out := make([]string, len(codes))
for i, c := range codes {
out[i] = strings.ToUpper(c)
}
return out
}
// Verdict is the result of an access check.
type Verdict int
const (
// Allow indicates the address passed all checks.
Allow Verdict = iota
// DenyCIDR indicates the address was blocked by a CIDR rule.
DenyCIDR
// DenyCountry indicates the address was blocked by a country rule.
DenyCountry
// DenyGeoUnavailable indicates that country restrictions are configured
// but the geo lookup is unavailable.
DenyGeoUnavailable
)
// String returns the deny reason string matching the HTTP auth mechanism names.
func (v Verdict) String() string {
switch v {
case Allow:
return "allow"
case DenyCIDR:
return "ip_restricted"
case DenyCountry:
return "country_restricted"
case DenyGeoUnavailable:
return "geo_unavailable"
default:
return "unknown"
}
}
// Check evaluates whether addr is permitted. CIDR rules are evaluated
// first because they are O(n) prefix comparisons. Country rules run
// only when CIDR checks pass and require a geo lookup.
func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict {
if f == nil {
return Allow
}
// Normalize v4-mapped-v6 (e.g. ::ffff:10.1.2.3) to plain v4 so that
// IPv4 CIDR rules match regardless of how the address was received.
addr = addr.Unmap()
if v := f.checkCIDR(addr); v != Allow {
return v
}
return f.checkCountry(addr, geo)
}
func (f *Filter) checkCIDR(addr netip.Addr) Verdict {
if len(f.AllowedCIDRs) > 0 {
allowed := false
for _, prefix := range f.AllowedCIDRs {
if prefix.Contains(addr) {
allowed = true
break
}
}
if !allowed {
return DenyCIDR
}
}
for _, prefix := range f.BlockedCIDRs {
if prefix.Contains(addr) {
return DenyCIDR
}
}
return Allow
}
func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict {
if len(f.AllowedCountries) == 0 && len(f.BlockedCountries) == 0 {
return Allow
}
if geo == nil || !geo.Available() {
return DenyGeoUnavailable
}
result := geo.LookupAddr(addr)
if result.CountryCode == "" {
// Unknown country: deny if an allowlist is active, allow otherwise.
// Blocklists are best-effort: unknown countries pass through since
// the default policy is allow.
if len(f.AllowedCountries) > 0 {
return DenyCountry
}
return Allow
}
if len(f.AllowedCountries) > 0 {
if !slices.Contains(f.AllowedCountries, result.CountryCode) {
return DenyCountry
}
}
if slices.Contains(f.BlockedCountries, result.CountryCode) {
return DenyCountry
}
return Allow
}
// HasRestrictions returns true if any restriction rules are configured.
func (f *Filter) HasRestrictions() bool {
if f == nil {
return false
}
return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 ||
len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0
}

View File

@@ -0,0 +1,278 @@
package restrict
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
)
type mockGeo struct {
countries map[string]string
}
func (m *mockGeo) LookupAddr(addr netip.Addr) geolocation.Result {
return geolocation.Result{CountryCode: m.countries[addr.String()]}
}
func (m *mockGeo) Available() bool { return true }
func newMockGeo(entries map[string]string) *mockGeo {
return &mockGeo{countries: entries}
}
func TestFilter_Check_NilFilter(t *testing.T) {
var f *Filter
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
}
func TestFilter_Check_AllowedCIDR(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_BlockedCIDR(t *testing.T) {
f := ParseFilter(nil, []string{"10.0.0.0/8"}, nil, nil)
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, []string{"10.1.0.0/16"}, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "not in allowlist")
}
func TestFilter_Check_AllowedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
f := ParseFilter(nil, nil, []string{"US", "DE"}, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
}
func TestFilter_Check_BlockedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "CN",
"2.2.2.2": "RU",
"3.3.3.3": "US",
})
f := ParseFilter(nil, nil, nil, []string{"CN", "RU"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "US not in blocklist")
}
func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
// Allow US and DE, but block DE explicitly.
f := ParseFilter(nil, nil, []string{"US", "DE"}, []string{"DE"})
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
}
func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
})
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active")
}
func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "CN",
})
f := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active")
}
func TestFilter_Check_CountryWithoutGeo(t *testing.T) {
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist")
}
func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) {
f := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist")
}
func TestFilter_Check_GeoUnavailable(t *testing.T) {
geo := &unavailableGeo{}
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist")
f2 := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist")
}
func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
// CIDR-only filter should never touch geo, so nil geo is fine.
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) {
geo := newMockGeo(map[string]string{
"10.1.2.3": "CN",
"10.2.3.4": "US",
})
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, []string{"CN"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), geo), "CIDR denied before country check")
}
func TestParseFilter_Empty(t *testing.T) {
f := ParseFilter(nil, nil, nil, nil)
assert.Nil(t, f)
}
func TestParseFilter_InvalidCIDR(t *testing.T) {
f := ParseFilter([]string{"invalid", "10.0.0.0/8"}, nil, nil, nil)
assert.NotNil(t, f)
assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped")
assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), f.AllowedCIDRs[0])
}
func TestFilter_HasRestrictions(t *testing.T) {
assert.False(t, (*Filter)(nil).HasRestrictions())
assert.False(t, (&Filter{}).HasRestrictions())
assert.True(t, ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil).HasRestrictions())
assert.True(t, ParseFilter(nil, nil, []string{"US"}, nil).HasRestrictions())
}
func TestFilter_Check_IPv6CIDR(t *testing.T) {
f := ParseFilter([]string{"2001:db8::/32"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 addr not in v6 allowlist")
}
func TestFilter_Check_IPv4MappedIPv6(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
// A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR.
v4mapped := netip.MustParseAddr("::ffff:10.1.2.3")
assert.True(t, v4mapped.Is4In6(), "precondition: address is v4-in-v6")
assert.Equal(t, Allow, f.Check(v4mapped, nil), "v4-mapped-v6 must match v4 CIDR after Unmap")
v4mappedOutside := netip.MustParseAddr("::ffff:192.168.1.1")
assert.Equal(t, DenyCIDR, f.Check(v4mappedOutside, nil), "v4-mapped-v6 outside v4 CIDR")
}
func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8", "2001:db8::/32"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("::ffff:10.1.2.3"), nil), "v4-mapped matches v4 CIDR")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "v4 not in either CIDR")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("fe80::1"), nil), "v6 not in either CIDR")
}
func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) {
// 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24.
f := ParseFilter([]string{"1.1.1.1/24"}, nil, nil, nil)
assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0])
// Verify it still matches correctly.
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.100"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("1.1.2.1"), nil))
}
func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
tests := []struct {
name string
allowedCountries []string
blockedCountries []string
addr string
want Verdict
}{
{
name: "lowercase allowlist matches uppercase MaxMind code",
allowedCountries: []string{"us", "de"},
addr: "1.1.1.1",
want: Allow,
},
{
name: "mixed-case allowlist matches",
allowedCountries: []string{"Us", "dE"},
addr: "2.2.2.2",
want: Allow,
},
{
name: "lowercase allowlist rejects non-matching country",
allowedCountries: []string{"us", "de"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "lowercase blocklist blocks matching country",
blockedCountries: []string{"cn"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "mixed-case blocklist blocks matching country",
blockedCountries: []string{"Cn"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "lowercase blocklist does not block non-matching country",
blockedCountries: []string{"cn"},
addr: "1.1.1.1",
want: Allow,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
f := ParseFilter(nil, nil, tc.allowedCountries, tc.blockedCountries)
got := f.Check(netip.MustParseAddr(tc.addr), geo)
assert.Equal(t, tc.want, got)
})
}
}
// unavailableGeo simulates a GeoResolver whose database is not loaded.
type unavailableGeo struct{}
func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} }
func (u *unavailableGeo) Available() bool { return false }

View File

@@ -7,12 +7,14 @@ import (
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
"strings"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/internal/types"
) )
@@ -20,6 +22,10 @@ import (
// timeout is configured. // timeout is configured.
const defaultDialTimeout = 30 * time.Second const defaultDialTimeout = 30 * time.Second
// errAccessRestricted is returned by relayTCP for access restriction
// denials so callers can skip warn-level logging (already logged at debug).
var errAccessRestricted = errors.New("rejected by access restrictions")
// SNIHost is a typed key for SNI hostname lookups. // SNIHost is a typed key for SNI hostname lookups.
type SNIHost string type SNIHost string
@@ -64,6 +70,11 @@ type Route struct {
// DialTimeout overrides the default dial timeout for this route. // DialTimeout overrides the default dial timeout for this route.
// Zero uses defaultDialTimeout. // Zero uses defaultDialTimeout.
DialTimeout time.Duration DialTimeout time.Duration
// SessionIdleTimeout overrides the default idle timeout for relay connections.
// Zero uses DefaultIdleTimeout.
SessionIdleTimeout time.Duration
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
Filter *restrict.Filter
} }
// l4Logger sends layer-4 access log entries to the management server. // l4Logger sends layer-4 access log entries to the management server.
@@ -99,6 +110,7 @@ type Router struct {
drainDone chan struct{} drainDone chan struct{}
observer RelayObserver observer RelayObserver
accessLog l4Logger accessLog l4Logger
geo restrict.GeoResolver
// svcCtxs tracks a context per service ID. All relay goroutines for a // svcCtxs tracks a context per service ID. All relay goroutines for a
// service derive from its context; canceling it kills them immediately. // service derive from its context; canceling it kills them immediately.
svcCtxs map[types.ServiceID]context.Context svcCtxs map[types.ServiceID]context.Context
@@ -144,6 +156,7 @@ func (r *Router) HTTPListener() net.Listener {
// stored and resolved by priority at lookup time (HTTP > TCP). // stored and resolved by priority at lookup time (HTTP > TCP).
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback. // Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
func (r *Router) AddRoute(host SNIHost, route Route) { func (r *Router) AddRoute(host SNIHost, route Route) {
host = SNIHost(strings.ToLower(string(host)))
if host == "" { if host == "" {
return return
} }
@@ -166,6 +179,8 @@ func (r *Router) AddRoute(host SNIHost, route Route) {
// Active relay connections for the service are closed immediately. // Active relay connections for the service are closed immediately.
// If other routes remain for the host, they are preserved. // If other routes remain for the host, they are preserved.
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) { func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
host = SNIHost(strings.ToLower(string(host)))
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
@@ -295,7 +310,7 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
return return
} }
host := SNIHost(sni) host := SNIHost(strings.ToLower(sni))
route, ok := r.lookupRoute(host) route, ok := r.lookupRoute(host)
if !ok { if !ok {
r.handleUnmatched(ctx, wrapped) r.handleUnmatched(ctx, wrapped)
@@ -308,11 +323,13 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
} }
if err := r.relayTCP(ctx, wrapped, host, route); err != nil { if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
r.logger.WithFields(log.Fields{ if !errors.Is(err, errAccessRestricted) {
"sni": host, r.logger.WithFields(log.Fields{
"service_id": route.ServiceID, "sni": host,
"target": route.Target, "service_id": route.ServiceID,
}).Warnf("TCP relay: %v", err) "target": route.Target,
}).Warnf("TCP relay: %v", err)
}
_ = wrapped.Close() _ = wrapped.Close()
} }
} }
@@ -336,10 +353,12 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
if fb != nil { if fb != nil {
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil { if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
r.logger.WithFields(log.Fields{ if !errors.Is(err, errAccessRestricted) {
"service_id": fb.ServiceID, r.logger.WithFields(log.Fields{
"target": fb.Target, "service_id": fb.ServiceID,
}).Warnf("TCP relay (fallback): %v", err) "target": fb.Target,
}).Warnf("TCP relay (fallback): %v", err)
}
_ = conn.Close() _ = conn.Close()
} }
return return
@@ -427,10 +446,44 @@ func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
} }
} }
// SetGeo sets the geolocation lookup used for country-based restrictions.
func (r *Router) SetGeo(geo restrict.GeoResolver) {
r.mu.Lock()
defer r.mu.Unlock()
r.geo = geo
}
// checkRestrictions evaluates the route's access filter against the
// connection's remote address. Returns Allow if the connection is
// permitted, or a deny verdict indicating the reason.
func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict {
if route.Filter == nil {
return restrict.Allow
}
addr, err := addrFromConn(conn)
if err != nil {
r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr())
return restrict.DenyCIDR
}
r.mu.RLock()
geo := r.geo
r.mu.RUnlock()
return route.Filter.Check(addr, geo)
}
// relayTCP sets up and runs a bidirectional TCP relay. // relayTCP sets up and runs a bidirectional TCP relay.
// The caller owns conn and must close it if this method returns an error. // The caller owns conn and must close it if this method returns an error.
// On success (nil error), both conn and backend are closed by the relay. // On success (nil error), both conn and backend are closed by the relay.
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error { func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
r.logL4Deny(route, conn, verdict)
return errAccessRestricted
}
svcCtx, err := r.acquireRelay(ctx, route) svcCtx, err := r.acquireRelay(ctx, route)
if err != nil { if err != nil {
return err return err
@@ -468,8 +521,13 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
}) })
entry.Debug("TCP relay started") entry.Debug("TCP relay started")
idleTimeout := route.SessionIdleTimeout
if idleTimeout <= 0 {
idleTimeout = DefaultIdleTimeout
}
start := time.Now() start := time.Now()
s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout) s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
elapsed := time.Since(start) elapsed := time.Since(start)
if obs != nil { if obs != nil {
@@ -537,12 +595,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
return return
} }
var sourceIP netip.Addr sourceIP, _ := addrFromConn(conn)
if remote := conn.RemoteAddr(); remote != nil {
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
sourceIP = ap.Addr().Unmap()
}
}
al.LogL4(accesslog.L4Entry{ al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID, AccountID: route.AccountID,
@@ -556,6 +609,28 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
}) })
} }
// logL4Deny sends an access log entry for a denied connection.
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) {
r.mu.RLock()
al := r.accessLog
r.mu.RUnlock()
if al == nil {
return
}
sourceIP, _ := addrFromConn(conn)
al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID,
ServiceID: route.ServiceID,
Protocol: route.Protocol,
Host: route.Domain,
SourceIP: sourceIP,
DenyReason: verdict.String(),
})
}
// getOrCreateServiceCtxLocked returns the context for a service, creating one // getOrCreateServiceCtxLocked returns the context for a service, creating one
// if it doesn't exist yet. The context is a child of the server context. // if it doesn't exist yet. The context is a child of the server context.
// Must be called with mu held. // Must be called with mu held.
@@ -568,3 +643,16 @@ func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types
r.svcCancels[svcID] = cancel r.svcCancels[svcID] = cancel
return ctx return ctx
} }
// addrFromConn extracts a netip.Addr from a connection's remote address.
func addrFromConn(conn net.Conn) (netip.Addr, error) {
remote := conn.RemoteAddr()
if remote == nil {
return netip.Addr{}, errors.New("no remote address")
}
ap, err := netip.ParseAddrPort(remote.String())
if err != nil {
return netip.Addr{}, err
}
return ap.Addr().Unmap(), nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/internal/types"
) )
@@ -1668,3 +1669,73 @@ func startEchoPlain(t *testing.T) net.Listener {
return ln return ln
} }
// fakeAddr implements net.Addr with a custom string representation.
type fakeAddr string
func (f fakeAddr) Network() string { return "tcp" }
func (f fakeAddr) String() string { return string(f) }
// fakeConn is a minimal net.Conn with a controllable RemoteAddr.
type fakeConn struct {
net.Conn
remote net.Addr
}
func (f *fakeConn) RemoteAddr() net.Addr { return f.remote }
func TestCheckRestrictions_UnparseableAddress(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "unparsable address must be denied")
}
func TestCheckRestrictions_NilRemoteAddr(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: nil}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "nil remote address must be denied")
}
func TestCheckRestrictions_AllowedAndDenied(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(allowed, route), "10.1.2.3 in allowlist")
denied := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(denied, route), "192.168.1.1 not in allowlist")
}
func TestCheckRestrictions_NilFilter(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
route := Route{Filter: nil}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "nil filter should allow everything")
}
func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
// net.IPv4() returns a 16-byte v4-in-v6 representation internally.
// The restriction check must Unmap it to match the v4 CIDR.
conn := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 5678}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "v4-in-v6 TCPAddr must match v4 CIDR")
// Explicitly v4-mapped-v6 address string.
conn6 := &fakeConn{remote: fakeAddr("[::ffff:10.1.2.3]:5678")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn6, route), "::ffff:10.1.2.3 must match v4 CIDR")
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/internal/types"
) )
@@ -67,6 +68,8 @@ type Relay struct {
dialTimeout time.Duration dialTimeout time.Duration
sessionTTL time.Duration sessionTTL time.Duration
maxSessions int maxSessions int
filter *restrict.Filter
geo restrict.GeoResolver
mu sync.RWMutex mu sync.RWMutex
sessions map[clientAddr]*session sessions map[clientAddr]*session
@@ -114,6 +117,10 @@ type RelayConfig struct {
SessionTTL time.Duration SessionTTL time.Duration
MaxSessions int MaxSessions int
AccessLog l4Logger AccessLog l4Logger
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
Filter *restrict.Filter
// Geo is the geolocation lookup used for country-based restrictions.
Geo restrict.GeoResolver
} }
// New creates a UDP relay for the given listener and backend target. // New creates a UDP relay for the given listener and backend target.
@@ -146,6 +153,8 @@ func New(parentCtx context.Context, cfg RelayConfig) *Relay {
dialTimeout: dialTimeout, dialTimeout: dialTimeout,
sessionTTL: sessionTTL, sessionTTL: sessionTTL,
maxSessions: maxSessions, maxSessions: maxSessions,
filter: cfg.Filter,
geo: cfg.Geo,
sessions: make(map[clientAddr]*session), sessions: make(map[clientAddr]*session),
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() any { New: func() any {
@@ -166,9 +175,18 @@ func (r *Relay) ServiceID() types.ServiceID {
// SetObserver sets the session lifecycle observer. Must be called before Serve. // SetObserver sets the session lifecycle observer. Must be called before Serve.
func (r *Relay) SetObserver(obs SessionObserver) { func (r *Relay) SetObserver(obs SessionObserver) {
r.mu.Lock()
defer r.mu.Unlock()
r.observer = obs r.observer = obs
} }
// getObserver returns the current session lifecycle observer.
func (r *Relay) getObserver() SessionObserver {
r.mu.RLock()
defer r.mu.RUnlock()
return r.observer
}
// Serve starts the relay loop. It blocks until the context is canceled // Serve starts the relay loop. It blocks until the context is canceled
// or the listener is closed. // or the listener is closed.
func (r *Relay) Serve() { func (r *Relay) Serve() {
@@ -209,8 +227,8 @@ func (r *Relay) Serve() {
} }
sess.bytesIn.Add(int64(nw)) sess.bytesIn.Add(int64(nw))
if r.observer != nil { if obs := r.getObserver(); obs != nil {
r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw) obs.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
} }
r.bufPool.Put(bufp) r.bufPool.Put(bufp)
} }
@@ -234,6 +252,10 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
return nil, r.ctx.Err() return nil, r.ctx.Err()
} }
if err := r.checkAccessRestrictions(addr); err != nil {
return nil, err
}
r.mu.Lock() r.mu.Lock()
if sess, ok = r.sessions[key]; ok && sess != nil { if sess, ok = r.sessions[key]; ok && sess != nil {
@@ -248,16 +270,16 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
if len(r.sessions) >= r.maxSessions { if len(r.sessions) >= r.maxSessions {
r.mu.Unlock() r.mu.Unlock()
if r.observer != nil { if obs := r.getObserver(); obs != nil {
r.observer.UDPSessionRejected(r.accountID) obs.UDPSessionRejected(r.accountID)
} }
return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions) return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
} }
if !r.sessLimiter.Allow() { if !r.sessLimiter.Allow() {
r.mu.Unlock() r.mu.Unlock()
if r.observer != nil { if obs := r.getObserver(); obs != nil {
r.observer.UDPSessionRejected(r.accountID) obs.UDPSessionRejected(r.accountID)
} }
return nil, fmt.Errorf("session creation rate limited") return nil, fmt.Errorf("session creation rate limited")
} }
@@ -274,8 +296,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
r.mu.Lock() r.mu.Lock()
delete(r.sessions, key) delete(r.sessions, key)
r.mu.Unlock() r.mu.Unlock()
if r.observer != nil { if obs := r.getObserver(); obs != nil {
r.observer.UDPSessionDialError(r.accountID) obs.UDPSessionDialError(r.accountID)
} }
return nil, fmt.Errorf("dial backend %s: %w", r.target, err) return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
} }
@@ -293,8 +315,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
r.sessions[key] = sess r.sessions[key] = sess
r.mu.Unlock() r.mu.Unlock()
if r.observer != nil { if obs := r.getObserver(); obs != nil {
r.observer.UDPSessionStarted(r.accountID) obs.UDPSessionStarted(r.accountID)
} }
r.sessWg.Go(func() { r.sessWg.Go(func() {
@@ -305,6 +327,21 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
return sess, nil return sess, nil
} }
func (r *Relay) checkAccessRestrictions(addr net.Addr) error {
if r.filter == nil {
return nil
}
clientIP, err := addrFromUDPAddr(addr)
if err != nil {
return fmt.Errorf("parse client address %s for restriction check: %w", addr, err)
}
if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow {
r.logDeny(clientIP, v)
return fmt.Errorf("access restricted for %s", addr)
}
return nil
}
// relayBackendToClient reads packets from the backend and writes them // relayBackendToClient reads packets from the backend and writes them
// back to the client through the public-facing listener. // back to the client through the public-facing listener.
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) { func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
@@ -332,8 +369,8 @@ func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
} }
sess.bytesOut.Add(int64(nw)) sess.bytesOut.Add(int64(nw))
if r.observer != nil { if obs := r.getObserver(); obs != nil {
r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw) obs.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
} }
} }
} }
@@ -402,9 +439,10 @@ func (r *Relay) cleanupIdleSessions() {
} }
r.mu.Unlock() r.mu.Unlock()
obs := r.getObserver()
for _, sess := range expired { for _, sess := range expired {
if r.observer != nil { if obs != nil {
r.observer.UDPSessionEnded(r.accountID) obs.UDPSessionEnded(r.accountID)
} }
r.logSessionEnd(sess) r.logSessionEnd(sess)
} }
@@ -429,8 +467,8 @@ func (r *Relay) removeSession(sess *session) {
if removed { if removed {
r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)", r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load()) sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
if r.observer != nil { if obs := r.getObserver(); obs != nil {
r.observer.UDPSessionEnded(r.accountID) obs.UDPSessionEnded(r.accountID)
} }
r.logSessionEnd(sess) r.logSessionEnd(sess)
} }
@@ -459,6 +497,22 @@ func (r *Relay) logSessionEnd(sess *session) {
}) })
} }
// logDeny sends an access log entry for a denied UDP packet.
func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) {
if r.accessLog == nil {
return
}
r.accessLog.LogL4(accesslog.L4Entry{
AccountID: r.accountID,
ServiceID: r.serviceID,
Protocol: accesslog.ProtocolUDP,
Host: r.domain,
SourceIP: clientIP,
DenyReason: verdict.String(),
})
}
// Close stops the relay, waits for all session goroutines to exit, // Close stops the relay, waits for all session goroutines to exit,
// and cleans up remaining sessions. // and cleans up remaining sessions.
func (r *Relay) Close() { func (r *Relay) Close() {
@@ -485,12 +539,22 @@ func (r *Relay) Close() {
} }
r.mu.Unlock() r.mu.Unlock()
obs := r.getObserver()
for _, sess := range closedSessions { for _, sess := range closedSessions {
if r.observer != nil { if obs != nil {
r.observer.UDPSessionEnded(r.accountID) obs.UDPSessionEnded(r.accountID)
} }
r.logSessionEnd(sess) r.logSessionEnd(sess)
} }
r.sessWg.Wait() r.sessWg.Wait()
} }
// addrFromUDPAddr extracts a netip.Addr from a net.Addr.
func addrFromUDPAddr(addr net.Addr) (netip.Addr, error) {
ap, err := netip.ParseAddrPort(addr.String())
if err != nil {
return netip.Addr{}, err
}
return ap.Addr().Unmap(), nil
}

View File

@@ -490,7 +490,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
logger := log.New() logger := log.New()
logger.SetLevel(log.WarnLevel) logger.SetLevel(log.WarnLevel)
authMw := auth.NewMiddleware(logger, nil) authMw := auth.NewMiddleware(logger, nil, nil)
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger) proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
clusterAddress := "test.proxy.io" clusterAddress := "test.proxy.io"
@@ -511,6 +511,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
0, 0,
proxytypes.AccountID(mapping.GetAccountId()), proxytypes.AccountID(mapping.GetAccountId()),
proxytypes.ServiceID(mapping.GetId()), proxytypes.ServiceID(mapping.GetId()),
nil,
) )
require.NoError(t, err) require.NoError(t, err)

View File

@@ -43,12 +43,14 @@ import (
"github.com/netbirdio/netbird/proxy/internal/certwatch" "github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/conntrack" "github.com/netbirdio/netbird/proxy/internal/conntrack"
"github.com/netbirdio/netbird/proxy/internal/debug" "github.com/netbirdio/netbird/proxy/internal/debug"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s" "github.com/netbirdio/netbird/proxy/internal/k8s"
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/netutil" "github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/roundtrip"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/internal/types"
@@ -59,7 +61,6 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
) )
// portRouter bundles a per-port Router with its listener and cancel func. // portRouter bundles a per-port Router with its listener and cancel func.
type portRouter struct { type portRouter struct {
router *nbtcp.Router router *nbtcp.Router
@@ -95,6 +96,9 @@ type Server struct {
// so they can be closed during graceful shutdown, since http.Server.Shutdown // so they can be closed during graceful shutdown, since http.Server.Shutdown
// does not handle them. // does not handle them.
hijackTracker conntrack.HijackTracker hijackTracker conntrack.HijackTracker
// geo resolves IP addresses to country/city for access restrictions and access logs.
geo restrict.GeoResolver
geoRaw *geolocation.Lookup
// routerReady is closed once mainRouter is fully initialized. // routerReady is closed once mainRouter is fully initialized.
// The mapping worker waits on this before processing updates. // The mapping worker waits on this before processing updates.
@@ -159,10 +163,38 @@ type Server struct {
// SupportsCustomPorts indicates whether the proxy can bind arbitrary // SupportsCustomPorts indicates whether the proxy can bind arbitrary
// ports for TCP/UDP/TLS services. // ports for TCP/UDP/TLS services.
SupportsCustomPorts bool SupportsCustomPorts bool
// DefaultDialTimeout is the default timeout for establishing backend // MaxDialTimeout caps the per-service backend dial timeout.
// connections when no per-service timeout is configured. Zero means // When the API sends a timeout, it is clamped to this value.
// each transport uses its own hardcoded default (typically 30s). // When the API sends no timeout, this value is used as the default.
DefaultDialTimeout time.Duration // Zero means no cap (the proxy honors whatever management sends).
MaxDialTimeout time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files for
// country-based access restrictions. Empty disables geo lookups.
GeoDataDir string
// MaxSessionIdleTimeout caps the per-service session idle timeout.
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
return s.MaxSessionIdleTimeout
}
return d
}
// clampDialTimeout returns d capped to MaxDialTimeout when configured.
// If d is zero, MaxDialTimeout is used as the default.
func (s *Server) clampDialTimeout(d time.Duration) time.Duration {
if s.MaxDialTimeout <= 0 {
return d
}
if d <= 0 || d > s.MaxDialTimeout {
return s.MaxDialTimeout
}
return d
} }
// NotifyStatus sends a status update to management about tunnel connectivity. // NotifyStatus sends a status update to management about tunnel connectivity.
@@ -226,7 +258,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
runCtx, runCancel := context.WithCancel(ctx) runCtx, runCancel := context.WithCancel(ctx)
defer runCancel() defer runCancel()
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
// Initialize the netbird client, this is required to build peer connections // Initialize the netbird client, this is required to build peer connections
// to proxy over. // to proxy over.
@@ -236,6 +267,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
PreSharedKey: s.PreSharedKey, PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient) }, s.Logger, s, s.mgmtClient)
// Create health checker before the mapping worker so it can track
// management connectivity from the first stream connection.
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
tlsConfig, err := s.configureTLS(ctx) tlsConfig, err := s.configureTLS(ctx)
if err != nil { if err != nil {
return err return err
@@ -244,14 +281,33 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger) s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
}
var startupOK bool
defer func() {
if startupOK {
return
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", err)
}
}
}()
// Configure the authentication middleware with session validator for OIDC group checks. // Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
// Configure Access logs to management server. // Configure Access logs to management server.
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.startDebugEndpoint() s.startDebugEndpoint()
if err := s.startHealthServer(); err != nil { if err := s.startHealthServer(); err != nil {
@@ -294,6 +350,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
} }
startupOK = true
httpsErr := make(chan error, 1) httpsErr := make(chan error, 1)
go func() { go func() {
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
@@ -691,6 +749,16 @@ func (s *Server) shutdownServices() {
s.portRouterWg.Wait() s.portRouterWg.Wait()
wg.Wait() wg.Wait()
if s.accessLog != nil {
s.accessLog.Close()
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation: %v", err)
}
}
} }
// resolveDialFunc returns a DialContextFunc that dials through the // resolveDialFunc returns a DialContextFunc that dials through the
@@ -1073,15 +1141,20 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TCP port %d: %w", port, err) return fmt.Errorf("router for TCP port %d: %w", port, err)
} }
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.SetFallback(nbtcp.Route{ router.SetFallback(nbtcp.Route{
Type: nbtcp.RouteTCP, Type: nbtcp.RouteTCP,
AccountID: accountID, AccountID: accountID,
ServiceID: svcID, ServiceID: svcID,
Domain: mapping.GetDomain(), Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTCP, Protocol: accesslog.ProtocolTCP,
Target: targetAddr, Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping), ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping), DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
}) })
s.portMu.Lock() s.portMu.Lock()
@@ -1108,6 +1181,8 @@ func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("empty target address for UDP service %s", svcID) return fmt.Errorf("empty target address for UDP service %s", svcID)
} }
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil { if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil {
return fmt.Errorf("UDP relay for service %s: %w", svcID, err) return fmt.Errorf("UDP relay for service %s: %w", svcID, err)
} }
@@ -1141,15 +1216,20 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TLS port %d: %w", tlsPort, err) return fmt.Errorf("router for TLS port %d: %w", tlsPort, err)
} }
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
Type: nbtcp.RouteTCP, Type: nbtcp.RouteTCP,
AccountID: accountID, AccountID: accountID,
ServiceID: svcID, ServiceID: svcID,
Domain: mapping.GetDomain(), Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTLS, Protocol: accesslog.ProtocolTLS,
Target: targetAddr, Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping), ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping), DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
}) })
if tlsPort != s.mainPort { if tlsPort != s.mainPort {
@@ -1181,6 +1261,32 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser
} }
} }
// parseRestrictions converts a proto mapping's access restrictions into
// a restrict.Filter. Returns nil if the mapping has no restrictions.
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
r := mapping.GetAccessRestrictions()
if r == nil {
return nil
}
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
}
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
// but the proxy has no geolocation database loaded. All requests to this
// service will be denied at runtime (fail-close).
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) {
if r == nil {
return
}
if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 {
return
}
if s.geo != nil && s.geo.Available() {
return
}
s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain)
}
// l4TargetAddress extracts and validates the target address from a mapping's // l4TargetAddress extracts and validates the target address from a mapping's
// first path entry. Returns empty string if no paths exist or the address is // first path entry. Returns empty string if no paths exist or the address is
// not a valid host:port. // not a valid host:port.
@@ -1210,15 +1316,15 @@ func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool {
} }
// l4DialTimeout returns the dial timeout from the first target's options, // l4DialTimeout returns the dial timeout from the first target's options,
// falling back to the server's DefaultDialTimeout. // clamped to MaxDialTimeout.
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration { func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration {
paths := mapping.GetPath() paths := mapping.GetPath()
if len(paths) > 0 { if len(paths) > 0 {
if d := paths[0].GetOptions().GetRequestTimeout(); d != nil { if d := paths[0].GetOptions().GetRequestTimeout(); d != nil {
return d.AsDuration() return s.clampDialTimeout(d.AsDuration())
} }
} }
return s.DefaultDialTimeout return s.clampDialTimeout(0)
} }
// l4SessionIdleTimeout returns the configured session idle timeout from the // l4SessionIdleTimeout returns the configured session idle timeout from the
@@ -1254,7 +1360,9 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
dialFn, err := s.resolveDialFunc(accountID) dialFn, err := s.resolveDialFunc(accountID)
if err != nil { if err != nil {
_ = listener.Close() if err := listener.Close(); err != nil {
s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err)
}
return fmt.Errorf("resolve dialer for UDP: %w", err) return fmt.Errorf("resolve dialer for UDP: %w", err)
} }
@@ -1273,8 +1381,10 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
ServiceID: svcID, ServiceID: svcID,
DialFunc: dialFn, DialFunc: dialFn,
DialTimeout: s.l4DialTimeout(mapping), DialTimeout: s.l4DialTimeout(mapping),
SessionTTL: l4SessionIdleTimeout(mapping), SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
AccessLog: s.accessLog, AccessLog: s.accessLog,
Filter: parseRestrictions(mapping),
Geo: s.geo,
}) })
relay.SetObserver(s.meter) relay.SetObserver(s.meter)
@@ -1306,9 +1416,15 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
if mapping.GetAuth().GetOidc() { if mapping.GetAuth().GetOidc() {
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto)) schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto))
} }
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
}
ipRestrictions := parseRestrictions(mapping)
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil { if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
} }
m := s.protoToMapping(ctx, mapping) m := s.protoToMapping(ctx, mapping)
@@ -1449,12 +1565,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
pt.RequestTimeout = d.AsDuration() pt.RequestTimeout = d.AsDuration()
} }
} }
if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 { pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
pt.RequestTimeout = s.DefaultDialTimeout
}
paths[pathMapping.GetPath()] = pt paths[pathMapping.GetPath()] = pt
} }
return proxy.Mapping{ m := proxy.Mapping{
ID: types.ServiceID(mapping.GetId()), ID: types.ServiceID(mapping.GetId()),
AccountID: types.AccountID(mapping.GetAccountId()), AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(), Host: mapping.GetDomain(),
@@ -1462,6 +1576,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
PassHostHeader: mapping.GetPassHostHeader(), PassHostHeader: mapping.GetPassHostHeader(),
RewriteRedirects: mapping.GetRewriteRedirects(), RewriteRedirects: mapping.GetRewriteRedirects(),
} }
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader())
}
return m
} }
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode { func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {

View File

@@ -2826,6 +2826,10 @@ components:
type: string type: string
description: "City name from geolocation" description: "City name from geolocation"
example: "San Francisco" example: "San Francisco"
subdivision_code:
type: string
description: "First-level administrative subdivision ISO code (e.g. state/province)"
example: "CA"
bytes_upload: bytes_upload:
type: integer type: integer
format: int64 format: int64
@@ -2952,26 +2956,32 @@ components:
id: id:
type: string type: string
description: Service ID description: Service ID
example: "cs8i4ug6lnn4g9hqv7mg"
name: name:
type: string type: string
description: Service name description: Service name
example: "myapp.example.netbird.app"
domain: domain:
type: string type: string
description: Domain for the service description: Domain for the service
example: "myapp.example.netbird.app"
mode: mode:
type: string type: string
description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough.
enum: [http, tcp, udp, tls] enum: [http, tcp, udp, tls]
default: http default: http
example: "http"
listen_port: listen_port:
type: integer type: integer
minimum: 0 minimum: 0
maximum: 65535 maximum: 65535
description: Port the proxy listens on (L4/TLS only) description: Port the proxy listens on (L4/TLS only)
example: 8443
port_auto_assigned: port_auto_assigned:
type: boolean type: boolean
description: Whether the listen port was auto-assigned description: Whether the listen port was auto-assigned
readOnly: true readOnly: true
example: false
proxy_cluster: proxy_cluster:
type: string type: string
description: The proxy cluster handling this service (derived from domain) description: The proxy cluster handling this service (derived from domain)
@@ -2984,14 +2994,19 @@ components:
enabled: enabled:
type: boolean type: boolean
description: Whether the service is enabled description: Whether the service is enabled
example: true
pass_host_header: pass_host_header:
type: boolean type: boolean
description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
example: false
rewrite_redirects: rewrite_redirects:
type: boolean type: boolean
description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
example: false
auth: auth:
$ref: '#/components/schemas/ServiceAuthConfig' $ref: '#/components/schemas/ServiceAuthConfig'
access_restrictions:
$ref: '#/components/schemas/AccessRestrictions'
meta: meta:
$ref: '#/components/schemas/ServiceMeta' $ref: '#/components/schemas/ServiceMeta'
required: required:
@@ -3035,19 +3050,23 @@ components:
name: name:
type: string type: string
description: Service name description: Service name
example: "myapp.example.netbird.app"
domain: domain:
type: string type: string
description: Domain for the service description: Domain for the service
example: "myapp.example.netbird.app"
mode: mode:
type: string type: string
description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough.
enum: [http, tcp, udp, tls] enum: [http, tcp, udp, tls]
default: http default: http
example: "http"
listen_port: listen_port:
type: integer type: integer
minimum: 0 minimum: 0
maximum: 65535 maximum: 65535
description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment. description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment.
example: 5432
targets: targets:
type: array type: array
items: items:
@@ -3057,14 +3076,19 @@ components:
type: boolean type: boolean
description: Whether the service is enabled description: Whether the service is enabled
default: true default: true
example: true
pass_host_header: pass_host_header:
type: boolean type: boolean
description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
example: false
rewrite_redirects: rewrite_redirects:
type: boolean type: boolean
description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
example: false
auth: auth:
$ref: '#/components/schemas/ServiceAuthConfig' $ref: '#/components/schemas/ServiceAuthConfig'
access_restrictions:
$ref: '#/components/schemas/AccessRestrictions'
required: required:
- name - name
- domain - domain
@@ -3075,13 +3099,16 @@ components:
skip_tls_verify: skip_tls_verify:
type: boolean type: boolean
description: Skip TLS certificate verification for this backend description: Skip TLS certificate verification for this backend
example: false
request_timeout: request_timeout:
type: string type: string
description: Per-target response timeout as a Go duration string (e.g. "30s", "2m") description: Per-target response timeout as a Go duration string (e.g. "30s", "2m")
example: "30s"
path_rewrite: path_rewrite:
type: string type: string
description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
enum: [preserve] enum: [preserve]
example: "preserve"
custom_headers: custom_headers:
type: object type: object
description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected. description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
@@ -3091,40 +3118,50 @@ components:
additionalProperties: additionalProperties:
type: string type: string
pattern: '^[^\r\n]*$' pattern: '^[^\r\n]*$'
example: {"X-Custom-Header": "value"}
proxy_protocol: proxy_protocol:
type: boolean type: boolean
description: Send PROXY Protocol v2 header to this backend (TCP/TLS only) description: Send PROXY Protocol v2 header to this backend (TCP/TLS only)
example: false
session_idle_timeout: session_idle_timeout:
type: string type: string
description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m").
example: "2m"
ServiceTarget: ServiceTarget:
type: object type: object
properties: properties:
target_id: target_id:
type: string type: string
description: Target ID description: Target ID
example: "cs8i4ug6lnn4g9hqv7mg"
target_type: target_type:
type: string type: string
description: Target type description: Target type
enum: [peer, host, domain, subnet] enum: [peer, host, domain, subnet]
example: "subnet"
path: path:
type: string type: string
description: URL path prefix for this target (HTTP only) description: URL path prefix for this target (HTTP only)
example: "/"
protocol: protocol:
type: string type: string
description: Protocol to use when connecting to the backend description: Protocol to use when connecting to the backend
enum: [http, https, tcp, udp] enum: [http, https, tcp, udp]
example: "http"
host: host:
type: string type: string
description: Backend ip or domain for this target description: Backend ip or domain for this target
example: "10.10.0.1"
port: port:
type: integer type: integer
minimum: 1 minimum: 1
maximum: 65535 maximum: 65535
description: Backend port for this target description: Backend port for this target
example: 8080
enabled: enabled:
type: boolean type: boolean
description: Whether this target is enabled description: Whether this target is enabled
example: true
options: options:
$ref: '#/components/schemas/ServiceTargetOptions' $ref: '#/components/schemas/ServiceTargetOptions'
required: required:
@@ -3144,15 +3181,73 @@ components:
$ref: '#/components/schemas/BearerAuthConfig' $ref: '#/components/schemas/BearerAuthConfig'
link_auth: link_auth:
$ref: '#/components/schemas/LinkAuthConfig' $ref: '#/components/schemas/LinkAuthConfig'
header_auths:
type: array
items:
$ref: '#/components/schemas/HeaderAuthConfig'
HeaderAuthConfig:
type: object
description: Static header-value authentication. The proxy checks that the named header matches the configured value.
properties:
enabled:
type: boolean
description: Whether header auth is enabled
example: true
header:
type: string
description: HTTP header name to check (e.g. "Authorization", "X-API-Key")
example: "X-API-Key"
value:
type: string
description: Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses.
example: "my-secret-api-key"
required:
- enabled
- header
- value
AccessRestrictions:
type: object
description: Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
properties:
allowed_cidrs:
type: array
items:
type: string
format: cidr
example: "192.168.1.0/24"
description: CIDR allowlist. If non-empty, only IPs matching these CIDRs are allowed.
blocked_cidrs:
type: array
items:
type: string
format: cidr
example: "10.0.0.0/8"
description: CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs.
allowed_countries:
type: array
items:
type: string
pattern: '^[a-zA-Z]{2}$'
example: "US"
description: ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted.
blocked_countries:
type: array
items:
type: string
pattern: '^[a-zA-Z]{2}$'
example: "DE"
description: ISO 3166-1 alpha-2 country codes to block.
PasswordAuthConfig: PasswordAuthConfig:
type: object type: object
properties: properties:
enabled: enabled:
type: boolean type: boolean
description: Whether password auth is enabled description: Whether password auth is enabled
example: true
password: password:
type: string type: string
description: Auth password description: Auth password
example: "s3cret"
required: required:
- enabled - enabled
- password - password
@@ -3162,9 +3257,11 @@ components:
enabled: enabled:
type: boolean type: boolean
description: Whether PIN auth is enabled description: Whether PIN auth is enabled
example: false
pin: pin:
type: string type: string
description: PIN value description: PIN value
example: "1234"
required: required:
- enabled - enabled
- pin - pin
@@ -3174,10 +3271,12 @@ components:
enabled: enabled:
type: boolean type: boolean
description: Whether bearer auth is enabled description: Whether bearer auth is enabled
example: true
distribution_groups: distribution_groups:
type: array type: array
items: items:
type: string type: string
example: "ch8i4ug6lnn4g9hqv7mg"
description: List of group IDs that can use bearer auth description: List of group IDs that can use bearer auth
required: required:
- enabled - enabled
@@ -3187,6 +3286,7 @@ components:
enabled: enabled:
type: boolean type: boolean
description: Whether link auth is enabled description: Whether link auth is enabled
example: false
required: required:
- enabled - enabled
ProxyCluster: ProxyCluster:
@@ -3217,20 +3317,25 @@ components:
id: id:
type: string type: string
description: Domain ID description: Domain ID
example: "ds8i4ug6lnn4g9hqv7mg"
domain: domain:
type: string type: string
description: Domain name description: Domain name
example: "example.netbird.app"
validated: validated:
type: boolean type: boolean
description: Whether the domain has been validated description: Whether the domain has been validated
example: true
type: type:
$ref: '#/components/schemas/ReverseProxyDomainType' $ref: '#/components/schemas/ReverseProxyDomainType'
target_cluster: target_cluster:
type: string type: string
description: The proxy cluster this domain is validated against (only for custom domains) description: The proxy cluster this domain is validated against (only for custom domains)
example: "eu.proxy.netbird.io"
supports_custom_ports: supports_custom_ports:
type: boolean type: boolean
description: Whether the cluster supports binding arbitrary TCP/UDP ports description: Whether the cluster supports binding arbitrary TCP/UDP ports
example: true
required: required:
- id - id
- domain - domain
@@ -3242,9 +3347,11 @@ components:
domain: domain:
type: string type: string
description: Domain name description: Domain name
example: "myapp.example.com"
target_cluster: target_cluster:
type: string type: string
description: The proxy cluster this domain should be validated against description: The proxy cluster this domain should be validated against
example: "eu.proxy.netbird.io"
required: required:
- domain - domain
- target_cluster - target_cluster

View File

@@ -1276,6 +1276,21 @@ func (e PutApiIntegrationsMspTenantsIdInviteJSONBodyValue) Valid() bool {
} }
} }
// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
type AccessRestrictions struct {
// AllowedCidrs CIDR allowlist. If non-empty, only IPs matching these CIDRs are allowed.
AllowedCidrs *[]string `json:"allowed_cidrs,omitempty"`
// AllowedCountries ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted.
AllowedCountries *[]string `json:"allowed_countries,omitempty"`
// BlockedCidrs CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs.
BlockedCidrs *[]string `json:"blocked_cidrs,omitempty"`
// BlockedCountries ISO 3166-1 alpha-2 country codes to block.
BlockedCountries *[]string `json:"blocked_countries,omitempty"`
}
// AccessiblePeer defines model for AccessiblePeer. // AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct { type AccessiblePeer struct {
// CityName Commonly used English name of the city // CityName Commonly used English name of the city
@@ -1988,6 +2003,18 @@ type GroupRequest struct {
Resources *[]Resource `json:"resources,omitempty"` Resources *[]Resource `json:"resources,omitempty"`
} }
// HeaderAuthConfig Static header-value authentication. The proxy checks that the named header matches the configured value.
type HeaderAuthConfig struct {
// Enabled Whether header auth is enabled
Enabled bool `json:"enabled"`
// Header HTTP header name to check (e.g. "Authorization", "X-API-Key")
Header string `json:"header"`
// Value Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses.
Value string `json:"value"`
}
// HuntressMatchAttributes Attribute conditions to match when approving agents // HuntressMatchAttributes Attribute conditions to match when approving agents
type HuntressMatchAttributes struct { type HuntressMatchAttributes struct {
// DefenderPolicyStatus Policy status of Defender AV for Managed Antivirus. // DefenderPolicyStatus Policy status of Defender AV for Managed Antivirus.
@@ -3324,6 +3351,9 @@ type ProxyAccessLog struct {
// StatusCode HTTP status code returned // StatusCode HTTP status code returned
StatusCode int `json:"status_code"` StatusCode int `json:"status_code"`
// SubdivisionCode First-level administrative subdivision ISO code (e.g. state/province)
SubdivisionCode *string `json:"subdivision_code,omitempty"`
// Timestamp Timestamp when the request was made // Timestamp Timestamp when the request was made
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
@@ -3562,7 +3592,9 @@ type SentinelOneMatchAttributesNetworkStatus string
// Service defines model for Service. // Service defines model for Service.
type Service struct { type Service struct {
Auth ServiceAuthConfig `json:"auth"` // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"`
Auth ServiceAuthConfig `json:"auth"`
// Domain Domain for the service // Domain Domain for the service
Domain string `json:"domain"` Domain string `json:"domain"`
@@ -3605,6 +3637,7 @@ type ServiceMode string
// ServiceAuthConfig defines model for ServiceAuthConfig. // ServiceAuthConfig defines model for ServiceAuthConfig.
type ServiceAuthConfig struct { type ServiceAuthConfig struct {
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"` BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"`
HeaderAuths *[]HeaderAuthConfig `json:"header_auths,omitempty"`
LinkAuth *LinkAuthConfig `json:"link_auth,omitempty"` LinkAuth *LinkAuthConfig `json:"link_auth,omitempty"`
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty"` PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty"` PinAuth *PINAuthConfig `json:"pin_auth,omitempty"`
@@ -3627,7 +3660,9 @@ type ServiceMetaStatus string
// ServiceRequest defines model for ServiceRequest. // ServiceRequest defines model for ServiceRequest.
type ServiceRequest struct { type ServiceRequest struct {
Auth *ServiceAuthConfig `json:"auth,omitempty"` // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"`
Auth *ServiceAuthConfig `json:"auth,omitempty"`
// Domain Domain for the service // Domain Domain for the service
Domain string `json:"domain"` Domain string `json:"domain"`
@@ -3702,7 +3737,7 @@ type ServiceTargetOptions struct {
// RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m") // RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m")
RequestTimeout *string `json:"request_timeout,omitempty"` RequestTimeout *string `json:"request_timeout,omitempty"`
// SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m. // SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m").
SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"` SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"`
// SkipTlsVerify Skip TLS certificate verification for this backend // SkipTlsVerify Skip TLS certificate verification for this backend

File diff suppressed because it is too large Load Diff

View File

@@ -80,12 +80,27 @@ message PathMapping {
PathTargetOptions options = 3; PathTargetOptions options = 3;
} }
message HeaderAuth {
// Header name to check, e.g. "Authorization", "X-API-Key".
string header = 1;
// argon2id hash of the expected full header value.
string hashed_value = 2;
}
message Authentication { message Authentication {
string session_key = 1; string session_key = 1;
int64 max_session_age_seconds = 2; int64 max_session_age_seconds = 2;
bool password = 3; bool password = 3;
bool pin = 4; bool pin = 4;
bool oidc = 5; bool oidc = 5;
repeated HeaderAuth header_auths = 6;
}
message AccessRestrictions {
repeated string allowed_cidrs = 1;
repeated string blocked_cidrs = 2;
repeated string allowed_countries = 3;
repeated string blocked_countries = 4;
} }
message ProxyMapping { message ProxyMapping {
@@ -106,6 +121,7 @@ message ProxyMapping {
string mode = 10; string mode = 10;
// For L4/TLS: the port the proxy listens on. // For L4/TLS: the port the proxy listens on.
int32 listen_port = 11; int32 listen_port = 11;
AccessRestrictions access_restrictions = 12;
} }
// SendAccessLogRequest consists of one or more AccessLogs from a Proxy. // SendAccessLogRequest consists of one or more AccessLogs from a Proxy.
@@ -141,9 +157,15 @@ message AuthenticateRequest {
oneof request { oneof request {
PasswordRequest password = 3; PasswordRequest password = 3;
PinRequest pin = 4; PinRequest pin = 4;
HeaderAuthRequest header_auth = 5;
} }
} }
message HeaderAuthRequest {
string header_value = 1;
string header_name = 2;
}
message PasswordRequest { message PasswordRequest {
string password = 1; string password = 1;
} }

View File

@@ -65,8 +65,8 @@ func (b *earlyMsgBuffer) put(peerID messages.PeerID, msg Msg) bool {
} }
entry := earlyMsg{ entry := earlyMsg{
peerID: peerID, peerID: peerID,
msg: msg, msg: msg,
createdAt: time.Now(), createdAt: time.Now(),
} }
elem := b.order.PushBack(entry) elem := b.order.PushBack(entry)