[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -20,22 +20,23 @@ const (
)
type AccessLogEntry struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
Protocol AccessLogProtocol `gorm:"index"`
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
SubdivisionCode string
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
Protocol AccessLogProtocol `gorm:"index"`
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -105,6 +106,11 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
cityName = &a.GeoLocation.CityName
}
var subdivisionCode *string
if a.SubdivisionCode != "" {
subdivisionCode = &a.SubdivisionCode
}
var protocol *string
if a.Protocol != "" {
p := string(a.Protocol)
@@ -112,22 +118,23 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
}
return &api.ProxyAccessLog{
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,
Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode,
SourceIp: sourceIP,
Reason: reason,
UserId: userID,
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
Protocol: protocol,
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,
Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode,
SourceIp: sourceIP,
Reason: reason,
UserId: userID,
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
SubdivisionCode: subdivisionCode,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
Protocol: protocol,
}
}

View File

@@ -41,6 +41,9 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
logEntry.GeoLocation.CityName = location.City.Names.En
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
if len(location.Subdivisions) > 0 {
logEntry.SubdivisionCode = location.Subdivisions[0].ISOCode
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"math/rand/v2"
"net/http"
"os"
"slices"
"strconv"
@@ -229,6 +230,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return fmt.Errorf("hash secrets: %w", err)
}
for i, h := range service.Auth.HeaderAuths {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
@@ -488,6 +495,9 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
@@ -544,18 +554,52 @@ func isHTTPFamily(mode string) bool {
return mode == "" || mode == "http"
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) {
if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
svc.Auth.PasswordAuth.Password == "" {
svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
svc.Auth.PinAuth.Pin == "" {
svc.Auth.PinAuth = existingService.Auth.PinAuth
}
preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
}
// preserveHeaderAuthHashes fills in empty header auth values from the existing
// service so that unchanged secrets are not lost on update.
func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) {
if len(headers) == 0 || len(existing) == 0 {
return
}
existingByHeader := make(map[string]string, len(existing))
for _, h := range existing {
if h != nil && h.Value != "" {
existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value
}
}
for _, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok {
h.Value = hash
}
}
}
}
// validateHeaderAuthValues checks that all enabled header auths have a value
// (either freshly provided or preserved from the existing service).
func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error {
for i, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
return nil
}
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
@@ -605,6 +649,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
}
return nil

View File

@@ -7,14 +7,15 @@ import (
"math/big"
"net"
"net/http"
"net/netip"
"net/url"
"regexp"
"slices"
"strconv"
"strings"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
@@ -91,10 +92,37 @@ type BearerAuthConfig struct {
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
// HeaderAuthConfig defines a static header-value auth check.
// The proxy compares the incoming header value against the stored hash.
type HeaderAuthConfig struct {
Enabled bool `json:"enabled"`
Header string `json:"header"`
Value string `json:"value"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"`
}
// AccessRestrictions controls who can connect to the service based on IP or geography.
type AccessRestrictions struct {
AllowedCIDRs []string `json:"allowed_cidrs,omitempty" gorm:"serializer:json"`
BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"`
AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"`
BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"`
}
// Copy returns a deep copy of the AccessRestrictions.
func (r AccessRestrictions) Copy() AccessRestrictions {
return AccessRestrictions{
AllowedCIDRs: slices.Clone(r.AllowedCIDRs),
BlockedCIDRs: slices.Clone(r.BlockedCIDRs),
AllowedCountries: slices.Clone(r.AllowedCountries),
BlockedCountries: slices.Clone(r.BlockedCountries),
}
}
func (a *AuthConfig) HashSecrets() error {
@@ -114,6 +142,16 @@ func (a *AuthConfig) HashSecrets() error {
a.PinAuth.Pin = hashedPin
}
for i, h := range a.HeaderAuths {
if h != nil && h.Enabled && h.Value != "" {
hashedValue, err := argon2id.Hash(h.Value)
if err != nil {
return fmt.Errorf("hash header auth[%d] value: %w", i, err)
}
h.Value = hashedValue
}
}
return nil
}
@@ -124,6 +162,11 @@ func (a *AuthConfig) ClearSecrets() {
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
for _, h := range a.HeaderAuths {
if h != nil {
h.Value = ""
}
}
}
type Meta struct {
@@ -143,12 +186,13 @@ type Service struct {
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
Auth AuthConfig `gorm:"serializer:json"`
Restrictions AccessRestrictions `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
// Mode determines the service type: "http", "tcp", "udp", or "tls".
Mode string `gorm:"default:'http'"`
ListenPort uint16
@@ -188,6 +232,20 @@ func (s *Service) ToAPIResponse() *api.Service {
}
}
if len(s.Auth.HeaderAuths) > 0 {
apiHeaders := make([]api.HeaderAuthConfig, 0, len(s.Auth.HeaderAuths))
for _, h := range s.Auth.HeaderAuths {
if h == nil {
continue
}
apiHeaders = append(apiHeaders, api.HeaderAuthConfig{
Enabled: h.Enabled,
Header: h.Header,
})
}
authConfig.HeaderAuths = &apiHeaders
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
@@ -222,18 +280,19 @@ func (s *Service) ToAPIResponse() *api.Service {
listenPort := int(s.ListenPort)
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
AccessRestrictions: restrictionsToAPI(s.Restrictions),
Meta: meta,
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
}
if s.ProxyCluster != "" {
@@ -263,7 +322,16 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
auth.Oidc = true
}
return &proto.ProxyMapping{
for _, h := range s.Auth.HeaderAuths {
if h != nil && h.Enabled {
auth.HeaderAuths = append(auth.HeaderAuths, &proto.HeaderAuth{
Header: h.Header,
HashedValue: h.Value,
})
}
}
mapping := &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
@@ -276,6 +344,12 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec
}
if r := restrictionsToProto(s.Restrictions); r != nil {
mapping.AccessRestrictions = r
}
return mapping
}
// buildPathMappings constructs PathMapping entries from targets.
@@ -334,8 +408,7 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
panic(fmt.Sprintf("unknown operation type: %v", op))
}
}
@@ -477,6 +550,10 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
s.Auth = authFromAPI(req.Auth)
}
if req.AccessRestrictions != nil {
s.Restrictions = restrictionsFromAPI(req.AccessRestrictions)
}
return nil
}
@@ -538,9 +615,70 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
}
auth.BearerAuth = bearerAuth
}
if reqAuth.HeaderAuths != nil {
for _, h := range *reqAuth.HeaderAuths {
auth.HeaderAuths = append(auth.HeaderAuths, &HeaderAuthConfig{
Enabled: h.Enabled,
Header: h.Header,
Value: h.Value,
})
}
}
return auth
}
func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions {
if r == nil {
return AccessRestrictions{}
}
var res AccessRestrictions
if r.AllowedCidrs != nil {
res.AllowedCIDRs = *r.AllowedCidrs
}
if r.BlockedCidrs != nil {
res.BlockedCIDRs = *r.BlockedCidrs
}
if r.AllowedCountries != nil {
res.AllowedCountries = *r.AllowedCountries
}
if r.BlockedCountries != nil {
res.BlockedCountries = *r.BlockedCountries
}
return res
}
func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions {
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
return nil
}
res := &api.AccessRestrictions{}
if len(r.AllowedCIDRs) > 0 {
res.AllowedCidrs = &r.AllowedCIDRs
}
if len(r.BlockedCIDRs) > 0 {
res.BlockedCidrs = &r.BlockedCIDRs
}
if len(r.AllowedCountries) > 0 {
res.AllowedCountries = &r.AllowedCountries
}
if len(r.BlockedCountries) > 0 {
res.BlockedCountries = &r.BlockedCountries
}
return res
}
func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions {
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
return nil
}
return &proto.AccessRestrictions{
AllowedCidrs: r.AllowedCIDRs,
BlockedCidrs: r.BlockedCIDRs,
AllowedCountries: r.AllowedCountries,
BlockedCountries: r.BlockedCountries,
}
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
@@ -557,6 +695,13 @@ func (s *Service) Validate() error {
s.Mode = ModeHTTP
}
if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil {
return err
}
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err
}
switch s.Mode {
case ModeHTTP:
return s.validateHTTPMode()
@@ -657,6 +802,21 @@ func (s *Service) validateL4Target(target *Target) error {
if target.Path != nil && *target.Path != "" && *target.Path != "/" {
return errors.New("path is not supported for L4 services")
}
if target.Options.SessionIdleTimeout < 0 {
return errors.New("session_idle_timeout must be positive for L4 services")
}
if target.Options.RequestTimeout < 0 {
return errors.New("request_timeout must be positive for L4 services")
}
if target.Options.SkipTLSVerify {
return errors.New("skip_tls_verify is not supported for L4 services")
}
if target.Options.PathRewrite != "" {
return errors.New("path_rewrite is not supported for L4 services")
}
if len(target.Options.CustomHeaders) > 0 {
return errors.New("custom_headers is not supported for L4 services")
}
return nil
}
@@ -688,11 +848,9 @@ func IsPortBasedProtocol(mode string) bool {
}
const (
maxRequestTimeout = 5 * time.Minute
maxSessionIdleTimeout = 10 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
@@ -731,22 +889,12 @@ func validateTargetOptions(idx int, opts *TargetOptions) error {
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
}
if opts.RequestTimeout != 0 {
if opts.RequestTimeout <= 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.RequestTimeout > maxRequestTimeout {
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
}
if opts.RequestTimeout < 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.SessionIdleTimeout != 0 {
if opts.SessionIdleTimeout <= 0 {
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
}
if opts.SessionIdleTimeout > maxSessionIdleTimeout {
return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout)
}
if opts.SessionIdleTimeout < 0 {
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
}
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
@@ -796,6 +944,93 @@ func containsCRLF(s string) bool {
return strings.ContainsAny(s, "\r\n")
}
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
seen := make(map[string]struct{})
for i, h := range headers {
if h == nil || !h.Enabled {
continue
}
if h.Header == "" {
return fmt.Errorf("header_auths[%d]: header name is required", i)
}
if !httpHeaderNameRe.MatchString(h.Header) {
return fmt.Errorf("header_auths[%d]: header name %q is not a valid HTTP header name", i, h.Header)
}
canonical := http.CanonicalHeaderKey(h.Header)
if _, ok := hopByHopHeaders[canonical]; ok {
return fmt.Errorf("header_auths[%d]: header %q is a hop-by-hop header and cannot be used for auth", i, h.Header)
}
if _, ok := reservedHeaders[canonical]; ok {
return fmt.Errorf("header_auths[%d]: header %q is managed by the proxy and cannot be used for auth", i, h.Header)
}
if canonical == "Host" {
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
}
if _, dup := seen[canonical]; dup {
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
}
seen[canonical] = struct{}{}
if len(h.Value) > maxHeaderValueLen {
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
}
}
return nil
}
const (
maxCIDREntries = 200
maxCountryEntries = 50
)
// validateAccessRestrictions validates and normalizes access restriction
// entries. Country codes are uppercased in place.
func validateAccessRestrictions(r *AccessRestrictions) error {
if len(r.AllowedCIDRs) > maxCIDREntries {
return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries)
}
if len(r.BlockedCIDRs) > maxCIDREntries {
return fmt.Errorf("blocked_cidrs: exceeds maximum of %d entries", maxCIDREntries)
}
if len(r.AllowedCountries) > maxCountryEntries {
return fmt.Errorf("allowed_countries: exceeds maximum of %d entries", maxCountryEntries)
}
if len(r.BlockedCountries) > maxCountryEntries {
return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries)
}
for i, raw := range r.AllowedCIDRs {
prefix, err := netip.ParsePrefix(raw)
if err != nil {
return fmt.Errorf("allowed_cidrs[%d]: %w", i, err)
}
if prefix != prefix.Masked() {
return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
}
}
for i, raw := range r.BlockedCIDRs {
prefix, err := netip.ParsePrefix(raw)
if err != nil {
return fmt.Errorf("blocked_cidrs[%d]: %w", i, err)
}
if prefix != prefix.Masked() {
return fmt.Errorf("blocked_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
}
}
for i, code := range r.AllowedCountries {
if len(code) != 2 {
return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
}
r.AllowedCountries[i] = strings.ToUpper(code)
}
for i, code := range r.BlockedCountries {
if len(code) != 2 {
return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
}
r.BlockedCountries[i] = strings.ToUpper(code)
}
return nil
}
func (s *Service) EventMeta() map[string]any {
meta := map[string]any{
"name": s.Name,
@@ -827,9 +1062,17 @@ func (s *Service) EventMeta() map[string]any {
}
func (s *Service) isAuthEnabled() bool {
return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled)
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) {
return true
}
for _, h := range s.Auth.HeaderAuths {
if h != nil && h.Enabled {
return true
}
}
return false
}
func (s *Service) Copy() *Service {
@@ -866,6 +1109,16 @@ func (s *Service) Copy() *Service {
}
authCopy.BearerAuth = &ba
}
if len(s.Auth.HeaderAuths) > 0 {
authCopy.HeaderAuths = make([]*HeaderAuthConfig, len(s.Auth.HeaderAuths))
for i, h := range s.Auth.HeaderAuths {
if h == nil {
continue
}
hCopy := *h
authCopy.HeaderAuths[i] = &hCopy
}
}
return &Service{
ID: s.ID,
@@ -878,6 +1131,7 @@ func (s *Service) Copy() *Service {
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: authCopy,
Restrictions: s.Restrictions.Copy(),
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,

View File

@@ -120,9 +120,9 @@ func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
}{
{"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""},
{"valid 10m", 10 * time.Minute, ""},
{"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
@@ -493,16 +494,17 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
// should be set on the copy.
func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
return &proto.ProxyMapping{
Type: m.Type,
Id: m.Id,
AccountId: m.AccountId,
Domain: m.Domain,
Path: m.Path,
Auth: m.Auth,
PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects,
Mode: m.Mode,
ListenPort: m.ListenPort,
Type: m.Type,
Id: m.Id,
AccountId: m.AccountId,
Domain: m.Domain,
Path: m.Path,
Auth: m.Auth,
PassHostHeader: m.PassHostHeader,
RewriteRedirects: m.RewriteRedirects,
Mode: m.Mode,
ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
}
}
@@ -561,6 +563,8 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
case *proto.AuthenticateRequest_Password:
return s.authenticatePassword(ctx, req.GetId(), v, service.Auth.PasswordAuth)
case *proto.AuthenticateRequest_HeaderAuth:
return s.authenticateHeader(ctx, req.GetId(), v, service.Auth.HeaderAuths)
default:
return false, "", ""
}
@@ -594,6 +598,35 @@ func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID
return true, "password-user", proxyauth.MethodPassword
}
func (s *ProxyServiceServer) authenticateHeader(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_HeaderAuth, auths []*rpservice.HeaderAuthConfig) (bool, string, proxyauth.Method) {
if len(auths) == 0 {
log.WithContext(ctx).Debugf("header authentication attempted but no header auths configured for service %s", serviceID)
return false, "", ""
}
headerName := http.CanonicalHeaderKey(req.HeaderAuth.GetHeaderName())
var lastErr error
for _, auth := range auths {
if auth == nil || !auth.Enabled {
continue
}
if headerName != "" && http.CanonicalHeaderKey(auth.Header) != headerName {
continue
}
if err := argon2id.Verify(req.HeaderAuth.GetHeaderValue(), auth.Value); err != nil {
lastErr = err
continue
}
return true, "header-user", proxyauth.MethodHeader
}
if lastErr != nil {
s.logAuthenticationError(ctx, lastErr, "Header")
}
return false, "", ""
}
func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err error, authType string) {
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
log.WithContext(ctx).Tracef("%s authentication failed: invalid credentials", authType)
@@ -752,6 +785,9 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
}
if redirectURL.Scheme != "https" && redirectURL.Scheme != "http" {
return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme")
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
@@ -836,12 +872,9 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid.
// The HMAC is verified before consuming the PKCE verifier to prevent
// an attacker from invalidating a legitimate user's auth flow.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|")
if len(parts) != 3 {
@@ -865,6 +898,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return "", "", errors.New("invalid state signature")
}
// Consume the PKCE verifier only after HMAC validation passes.
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
return verifier, redirectURL, nil
}

View File

@@ -44,6 +44,12 @@ type Record struct {
GeonameID uint `maxminddb:"geoname_id"`
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
Subdivisions []struct {
ISOCode string `maxminddb:"iso_code"`
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"subdivisions"`
}
type City struct {

View File

@@ -10,7 +10,7 @@ FROM gcr.io/distroless/base:debug
COPY netbird-proxy /go/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
USER netbird:netbird
ENV HOME=/var/lib/netbird

View File

@@ -28,7 +28,7 @@ FROM gcr.io/distroless/base:debug
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
USER netbird:netbird
ENV HOME=/var/lib/netbird

View File

@@ -13,10 +13,11 @@ import (
type Method string
var (
const (
MethodPassword Method = "password"
MethodPIN Method = "pin"
MethodOIDC Method = "oidc"
MethodHeader Method = "header"
)
func (m Method) String() string {

View File

@@ -36,31 +36,33 @@ var (
var (
logLevel string
debugLogs bool
mgmtAddr string
addr string
proxyDomain string
defaultDialTimeout time.Duration
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeEABKID string
acmeEABHMACKey string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wildcardCertDir string
wgPort uint16
proxyProtocol bool
preSharedKey string
supportsCustomPorts bool
debugLogs bool
mgmtAddr string
addr string
proxyDomain string
maxDialTimeout time.Duration
maxSessionIdleTimeout time.Duration
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeEABKID string
acmeEABHMACKey string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wildcardCertDir string
wgPort uint16
proxyProtocol bool
preSharedKey string
supportsCustomPorts bool
geoDataDir string
)
var rootCmd = &cobra.Command{
@@ -99,7 +101,9 @@ func init() {
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)")
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
}
// Execute runs the root command.
@@ -177,17 +181,15 @@ func runServer(cmd *cobra.Command, args []string) error {
ProxyProtocol: proxyProtocol,
PreSharedKey: preSharedKey,
SupportsCustomPorts: supportsCustomPorts,
DefaultDialTimeout: defaultDialTimeout,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
GeoDataDir: geoDataDir,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()
if err := srv.ListenAndServe(ctx, addr); err != nil {
logger.Error(err)
return err
}
return nil
return srv.ListenAndServe(ctx, addr)
}
func envBoolOrDefault(key string, def bool) bool {
@@ -197,6 +199,7 @@ func envBoolOrDefault(key string, def bool) bool {
}
parsed, err := strconv.ParseBool(v)
if err != nil {
log.Warnf("parse %s=%q: %v, using default %v", key, v, err, def)
return def
}
return parsed
@@ -217,6 +220,7 @@ func envUint16OrDefault(key string, def uint16) uint16 {
}
parsed, err := strconv.ParseUint(v, 10, 16)
if err != nil {
log.Warnf("parse %s=%q: %v, using default %d", key, v, err, def)
return def
}
return uint16(parsed)
@@ -229,6 +233,7 @@ func envDurationOrDefault(key string, def time.Duration) time.Duration {
}
parsed, err := time.ParseDuration(v)
if err != nil {
log.Warnf("parse %s=%q: %v, using default %s", key, v, err, def)
return def
}
return parsed

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
@@ -22,6 +23,16 @@ const (
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
logSendTimeout = 10 * time.Second
// denyCooldown is the min interval between deny log entries per service+reason
// to prevent flooding from denied connections (e.g. UDP packets from blocked IPs).
denyCooldown = 10 * time.Second
// maxDenyBuckets caps tracked deny rate-limit entries to bound memory under DDoS.
maxDenyBuckets = 10000
// maxLogWorkers caps concurrent gRPC send goroutines.
maxLogWorkers = 4096
)
type domainUsage struct {
@@ -38,6 +49,18 @@ type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
}
// denyBucketKey identifies a rate-limited deny log stream.
type denyBucketKey struct {
ServiceID types.ServiceID
Reason string
}
// denyBucket tracks rate-limited deny log entries.
type denyBucket struct {
lastLogged time.Time
suppressed int64
}
// Logger sends access log entries to the management server via gRPC.
type Logger struct {
client gRPCClient
@@ -47,7 +70,12 @@ type Logger struct {
usageMux sync.Mutex
domainUsage map[string]*domainUsage
denyMu sync.Mutex
denyBuckets map[denyBucketKey]*denyBucket
logSem chan struct{}
cleanupCancel context.CancelFunc
dropped atomic.Int64
}
// NewLogger creates a new access log Logger. The trustedProxies parameter
@@ -64,6 +92,8 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre
logger: logger,
trustedProxies: trustedProxies,
domainUsage: make(map[string]*domainUsage),
denyBuckets: make(map[denyBucketKey]*denyBucket),
logSem: make(chan struct{}, maxLogWorkers),
cleanupCancel: cancel,
}
@@ -83,7 +113,7 @@ func (l *Logger) Close() {
type logEntry struct {
ID string
AccountID types.AccountID
ServiceId types.ServiceID
ServiceID types.ServiceID
Host string
Path string
DurationMs int64
@@ -91,7 +121,7 @@ type logEntry struct {
ResponseCode int32
SourceIP netip.Addr
AuthMechanism string
UserId string
UserID string
AuthSuccess bool
BytesUpload int64
BytesDownload int64
@@ -118,6 +148,10 @@ type L4Entry struct {
DurationMs int64
BytesUpload int64
BytesDownload int64
// DenyReason, when non-empty, indicates the connection was denied.
// Values match the HTTP auth mechanism strings: "ip_restricted",
// "country_restricted", "geo_unavailable".
DenyReason string
}
// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP).
@@ -126,7 +160,7 @@ func (l *Logger) LogL4(entry L4Entry) {
le := logEntry{
ID: xid.New().String(),
AccountID: entry.AccountID,
ServiceId: entry.ServiceID,
ServiceID: entry.ServiceID,
Protocol: entry.Protocol,
Host: entry.Host,
SourceIP: entry.SourceIP,
@@ -134,10 +168,47 @@ func (l *Logger) LogL4(entry L4Entry) {
BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload,
}
if entry.DenyReason != "" {
if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) {
return
}
le.AuthMechanism = entry.DenyReason
le.AuthSuccess = false
}
l.log(le)
l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload)
}
// allowDenyLog rate-limits deny log entries per service+reason combination.
func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool {
key := denyBucketKey{ServiceID: serviceID, Reason: reason}
now := time.Now()
l.denyMu.Lock()
defer l.denyMu.Unlock()
b, ok := l.denyBuckets[key]
if !ok {
if len(l.denyBuckets) >= maxDenyBuckets {
return false
}
l.denyBuckets[key] = &denyBucket{lastLogged: now}
return true
}
if now.Sub(b.lastLogged) >= denyCooldown {
if b.suppressed > 0 {
l.logger.Debugf("access restriction: suppressed %d deny log entries for %s (%s)", b.suppressed, serviceID, reason)
}
b.lastLogged = now
b.suppressed = 0
return true
}
b.suppressed++
return false
}
func (l *Logger) log(entry logEntry) {
// Fire off the log request in a separate routine.
// This increases the possibility of losing a log message
@@ -147,12 +218,21 @@ func (l *Logger) log(entry logEntry) {
// There is also a chance that log messages will arrive at
// the server out of order; however, the timestamp should
// allow for resolving that on the server.
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
now := timestamppb.Now()
select {
case l.logSem <- struct{}{}:
default:
total := l.dropped.Add(1)
l.logger.Debugf("access log send dropped: worker limit reached (total dropped: %d)", total)
return
}
go func() {
defer func() { <-l.logSem }()
logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout)
defer cancel()
// Only OIDC sessions have a meaningful user identity.
if entry.AuthMechanism != auth.MethodOIDC.String() {
entry.UserId = ""
entry.UserID = ""
}
var sourceIP string
@@ -165,7 +245,7 @@ func (l *Logger) log(entry logEntry) {
LogId: entry.ID,
AccountId: string(entry.AccountID),
Timestamp: now,
ServiceId: string(entry.ServiceId),
ServiceId: string(entry.ServiceID),
Host: entry.Host,
Path: entry.Path,
DurationMs: entry.DurationMs,
@@ -173,7 +253,7 @@ func (l *Logger) log(entry logEntry) {
ResponseCode: entry.ResponseCode,
SourceIp: sourceIP,
AuthMechanism: entry.AuthMechanism,
UserId: entry.UserId,
UserId: entry.UserID,
AuthSuccess: entry.AuthSuccess,
BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload,
@@ -181,7 +261,7 @@ func (l *Logger) log(entry logEntry) {
},
}); err != nil {
l.logger.WithFields(log.Fields{
"service_id": entry.ServiceId,
"service_id": entry.ServiceID,
"host": entry.Host,
"path": entry.Path,
"duration": entry.DurationMs,
@@ -189,7 +269,7 @@ func (l *Logger) log(entry logEntry) {
"response_code": entry.ResponseCode,
"source_ip": sourceIP,
"auth_mechanism": entry.AuthMechanism,
"user_id": entry.UserId,
"user_id": entry.UserID,
"auth_success": entry.AuthSuccess,
"error": err,
}).Error("Error sending access log on gRPC connection")
@@ -248,7 +328,7 @@ func (l *Logger) trackUsage(domain string, bytesTransferred int64) {
}
}
// cleanupStaleUsage removes usage entries for domains that have been inactive.
// cleanupStaleUsage removes usage and deny-rate-limit entries that have been inactive.
func (l *Logger) cleanupStaleUsage(ctx context.Context) {
ticker := time.NewTicker(usageCleanupPeriod)
defer ticker.Stop()
@@ -258,20 +338,41 @@ func (l *Logger) cleanupStaleUsage(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
l.usageMux.Lock()
now := time.Now()
removed := 0
for domain, usage := range l.domainUsage {
if now.Sub(usage.lastActivity) > usageInactiveWindow {
delete(l.domainUsage, domain)
removed++
}
}
l.usageMux.Unlock()
if removed > 0 {
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
}
l.cleanupDomainUsage(now)
l.cleanupDenyBuckets(now)
}
}
}
func (l *Logger) cleanupDomainUsage(now time.Time) {
l.usageMux.Lock()
defer l.usageMux.Unlock()
removed := 0
for domain, usage := range l.domainUsage {
if now.Sub(usage.lastActivity) > usageInactiveWindow {
delete(l.domainUsage, domain)
removed++
}
}
if removed > 0 {
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
}
}
func (l *Logger) cleanupDenyBuckets(now time.Time) {
l.denyMu.Lock()
defer l.denyMu.Unlock()
removed := 0
for key, bucket := range l.denyBuckets {
if now.Sub(bucket.lastLogged) > usageInactiveWindow {
delete(l.denyBuckets, key)
removed++
}
}
if removed > 0 {
l.logger.Debugf("cleaned up %d stale deny rate-limit entries", removed)
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/proxy/web"
)
// Middleware wraps an HTTP handler to log access entries and resolve client IPs.
func (l *Logger) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip logging for internal proxy assets (CSS, JS, etc.)
@@ -47,8 +48,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
// Create a mutable struct to capture data from downstream handlers.
// We pass a pointer in the context - the pointer itself flows down immutably,
// but the struct it points to can be mutated by inner handlers.
capturedData := &proxy.CapturedData{RequestID: requestID}
capturedData := proxy.NewCapturedData(requestID)
capturedData.SetClientIP(sourceIp)
ctx := proxy.WithCapturedData(r.Context(), capturedData)
start := time.Now()
@@ -66,8 +68,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
entry := logEntry{
ID: requestID,
ServiceId: capturedData.GetServiceId(),
AccountID: capturedData.GetAccountId(),
ServiceID: capturedData.GetServiceID(),
AccountID: capturedData.GetAccountID(),
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
@@ -75,14 +77,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
ResponseCode: int32(sw.status),
SourceIP: sourceIp,
AuthMechanism: capturedData.GetAuthMethod(),
UserId: capturedData.GetUserID(),
UserID: capturedData.GetUserID(),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
BytesUpload: bytesUpload,
BytesDownload: bytesDownload,
Protocol: ProtocolHTTP,
}
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID())
l.log(entry)

View File

@@ -0,0 +1,69 @@
package auth
import (
"errors"
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ErrHeaderAuthFailed indicates that the header was present but the
// credential did not validate. Callers should return 401 instead of
// falling through to other auth schemes.
var ErrHeaderAuthFailed = errors.New("header authentication failed")
// Header implements header-based authentication. The proxy checks for the
// configured header in each request and validates its value via gRPC.
type Header struct {
id types.ServiceID
accountId types.AccountID
headerName string
client authenticator
}
// NewHeader creates a Header authentication scheme for the given header name.
func NewHeader(client authenticator, id types.ServiceID, accountId types.AccountID, headerName string) Header {
return Header{
id: id,
accountId: accountId,
headerName: headerName,
client: client,
}
}
// Type returns auth.MethodHeader.
func (Header) Type() auth.Method {
return auth.MethodHeader
}
// Authenticate checks for the configured header in the request. If absent,
// returns empty (unauthenticated). If present, validates via gRPC.
func (h Header) Authenticate(r *http.Request) (string, string, error) {
value := r.Header.Get(h.headerName)
if value == "" {
return "", "", nil
}
res, err := h.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
Id: string(h.id),
AccountId: string(h.accountId),
Request: &proto.AuthenticateRequest_HeaderAuth{
HeaderAuth: &proto.HeaderAuthRequest{
HeaderValue: value,
HeaderName: h.headerName,
},
},
})
if err != nil {
return "", "", fmt.Errorf("authenticate header: %w", err)
}
if res.GetSuccess() {
return res.GetSessionToken(), "", nil
}
return "", "", ErrHeaderAuthFailed
}

View File

@@ -4,9 +4,12 @@ import (
"context"
"crypto/ed25519"
"encoding/base64"
"errors"
"fmt"
"html"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
@@ -16,11 +19,16 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
"github.com/netbirdio/netbird/shared/management/proto"
)
// errValidationUnavailable indicates that session validation failed due to
// an infrastructure error (e.g. gRPC unavailable), not an invalid token.
var errValidationUnavailable = errors.New("session validation unavailable")
type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
}
@@ -40,12 +48,14 @@ type Scheme interface {
Authenticate(*http.Request) (token string, promptData string, err error)
}
// DomainConfig holds the authentication and restriction settings for a protected domain.
type DomainConfig struct {
Schemes []Scheme
SessionPublicKey ed25519.PublicKey
SessionExpiration time.Duration
AccountID types.AccountID
ServiceID types.ServiceID
IPRestrictions *restrict.Filter
}
type validationResult struct {
@@ -54,17 +64,18 @@ type validationResult struct {
DeniedReason string
}
// Middleware applies per-domain authentication and IP restriction checks.
type Middleware struct {
domainsMux sync.RWMutex
domains map[string]DomainConfig
logger *log.Logger
sessionValidator SessionValidator
geo restrict.GeoResolver
}
// NewMiddleware creates a new authentication middleware.
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
// locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
// NewMiddleware creates a new authentication middleware. The sessionValidator is
// optional; if nil, OIDC session tokens are validated locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware {
if logger == nil {
logger = log.StandardLogger()
}
@@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl
domains: make(map[string]DomainConfig),
logger: logger,
sessionValidator: sessionValidator,
geo: geo,
}
}
// Protect applies authentication middleware to the passed handler.
// For each incoming request it will be checked against the middleware's
// internal list of protected domains.
// If the Host domain in the inbound request is not present, then it will
// simply be passed through.
// However, if the Host domain is present, then the specified authentication
// schemes for that domain will be applied to the request.
// In the event that no authentication schemes are defined for the domain,
// then the request will also be simply passed through.
// Protect wraps next with per-domain authentication and IP restriction checks.
// Requests whose Host is not registered pass through unchanged.
func (mw *Middleware) Protect(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
@@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
config, exists := mw.getDomainConfig(host)
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
if !exists || len(config.Schemes) == 0 {
if !exists {
next.ServeHTTP(w, r)
return
}
@@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
// Set account and service IDs in captured data for access logging.
setCapturedIDs(r, config)
if !mw.checkIPRestrictions(w, r, config) {
return
}
// Domains with no authentication schemes pass through after IP checks.
if len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
return
}
if mw.handleOAuthCallbackError(w, r) {
return
}
@@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
if mw.forwardWithHeaderAuth(w, r, host, config, next) {
return
}
mw.authenticateWithSchemes(w, r, host, config)
})
}
@@ -124,11 +142,65 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
func setCapturedIDs(r *http.Request, config DomainConfig) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(config.AccountID)
cd.SetServiceId(config.ServiceID)
cd.SetAccountID(config.AccountID)
cd.SetServiceID(config.ServiceID)
}
}
// checkIPRestrictions validates the client IP against the domain's IP restrictions.
// Uses the resolved client IP from CapturedData (which accounts for trusted proxies)
// rather than r.RemoteAddr directly.
func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
if config.IPRestrictions == nil {
return true
}
clientIP := mw.resolveClientIP(r)
if !clientIP.IsValid() {
mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
if verdict == restrict.Allow {
return true
}
reason := verdict.String()
mw.blockIPRestriction(r, reason)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr.
func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
if ip := cd.GetClientIP(); ip.IsValid() {
return ip
}
}
clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr)
if clientIPStr == "" {
clientIPStr = r.RemoteAddr
}
addr, err := netip.ParseAddr(clientIPStr)
if err != nil {
return netip.Addr{}
}
return addr.Unmap()
}
// blockIPRestriction sets captured data fields for an IP-restriction block event.
func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(reason)
}
mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr)
}
// handleOAuthCallbackError checks for error query parameters from an OAuth
// callback and renders the access denied page if present.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
@@ -146,6 +218,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re
errDesc := r.URL.Query().Get("error_description")
if errDesc == "" {
errDesc = "An error occurred during authentication"
} else {
errDesc = html.EscapeString(errDesc)
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
return true
@@ -170,6 +244,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
return true
}
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
// the request is forwarded directly (no redirect), which is important for API clients.
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
for _, scheme := range config.Schemes {
hdr, ok := scheme.(Header)
if !ok {
continue
}
handled := mw.tryHeaderScheme(w, r, host, config, hdr, next)
if handled {
return true
}
}
return false
}
func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool {
token, _, err := hdr.Authenticate(r)
if err != nil {
return mw.handleHeaderAuthError(w, r, err)
}
if token == "" {
return false
}
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
if err != nil {
setHeaderCapturedData(r.Context(), "")
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return true
}
if !result.Valid {
setHeaderCapturedData(r.Context(), result.UserID)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
setSessionCookie(w, token, config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(result.UserID)
cd.SetAuthMethod(auth.MethodHeader.String())
}
next.ServeHTTP(w, r)
return true
}
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
if errors.Is(err, ErrHeaderAuthFailed) {
setHeaderCapturedData(r.Context(), "")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return true
}
func setHeaderCapturedData(ctx context.Context, userID string) {
cd := proxy.CapturedDataFromContext(ctx)
if cd == nil {
return
}
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodHeader.String())
cd.SetUserID(userID)
}
// authenticateWithSchemes tries each configured auth scheme in order.
// On success it sets a session cookie and redirects; on failure it renders the login page.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
@@ -217,7 +370,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(scheme.Type().String())
}
http.Error(w, err.Error(), http.StatusBadRequest)
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return
}
@@ -233,7 +392,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
return
}
expiration := config.SessionExpiration
setSessionCookie(w, token, config.SessionExpiration)
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// setSessionCookie writes a session cookie with secure defaults.
func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) {
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
@@ -245,16 +418,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
@@ -275,13 +438,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error {
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
mw.domains[domain] = DomainConfig{
AccountID: accountID,
ServiceID: serviceID,
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
}
return nil
}
@@ -302,30 +466,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
SessionExpiration: expiration,
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
}
return nil
}
// RemoveDomain unregisters authentication for the given domain.
func (mw *Middleware) RemoveDomain(domain string) {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
delete(mw.domains, domain)
}
// validateSessionToken validates a session token, optionally checking group access via gRPC.
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
// For other auth methods (PIN, password), it validates the JWT locally.
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
// validateSessionToken validates a session token. OIDC tokens with a configured
// validator go through gRPC for group access checks; other methods validate locally.
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
// For OIDC with a session validator, call the gRPC service to check group access
if method == auth.MethodOIDC && mw.sessionValidator != nil {
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
Domain: host,
SessionToken: token,
})
if err != nil {
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
return nil, fmt.Errorf("session validation failed")
return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err)
}
if !resp.Valid {
mw.logger.WithFields(log.Fields{
@@ -342,7 +504,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
return &validationResult{UserID: resp.UserId, Valid: true}, nil
}
// For non-OIDC methods or when no validator is configured, validate JWT locally
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return nil, err

View File

@@ -1,11 +1,14 @@
package auth
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strings"
"testing"
@@ -14,10 +17,13 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/shared/management/proto"
)
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
@@ -52,11 +58,11 @@ func newPassthroughHandler() http.Handler {
}
func TestAddDomain_ValidKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
require.NoError(t, err)
mw.domainsMux.RLock()
@@ -70,10 +76,10 @@ func TestAddDomain_ValidKey(t *testing.T) {
}
func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -84,10 +90,10 @@ func TestAddDomain_EmptyKey(t *testing.T) {
}
func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key")
@@ -98,11 +104,11 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
}
func TestAddDomain_WrongKeySize(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -113,9 +119,9 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
}
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "")
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock()
@@ -125,14 +131,14 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
}
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
mw.domainsMux.RLock()
config := mw.domains["example.com"]
@@ -144,11 +150,11 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
}
func TestRemoveDomain(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
mw.RemoveDomain("example.com")
@@ -159,7 +165,7 @@ func TestRemoveDomain(t *testing.T) {
}
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
@@ -171,8 +177,8 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
}
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", ""))
mw := NewMiddleware(log.StandardLogger(), nil, nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -185,11 +191,11 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
}
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -206,11 +212,11 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
}
func TestProtect_HostWithPortIsMatched(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -227,16 +233,16 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
}
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd)
@@ -257,11 +263,11 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
@@ -283,11 +289,11 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
@@ -309,12 +315,12 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
// Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
@@ -336,7 +342,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
}
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
@@ -351,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -386,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
}
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -395,7 +401,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -409,7 +415,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
}
func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
@@ -431,7 +437,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -451,7 +457,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
}
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
// Return a garbage token that won't validate.
@@ -461,7 +467,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -473,7 +479,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
}
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
// 32 random bytes that happen to be valid base64 and correct size
// but are actually a valid ed25519 public key length-wise.
@@ -485,19 +491,19 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "")
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
}
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
require.Error(t, err)
// The original valid config should still be intact.
@@ -511,7 +517,7 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
}
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
// Scheme that always fails authentication (returns empty token)
@@ -521,9 +527,9 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
// Submit wrong PIN - should capture auth method
@@ -539,7 +545,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
}
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -548,9 +554,9 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
// Submit wrong password - should capture auth method
@@ -566,7 +572,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
}
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -575,9 +581,9 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
// No credentials submitted - should not capture auth method
@@ -658,3 +664,271 @@ func TestWasCredentialSubmitted(t *testing.T) {
})
}
}
func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
remoteAddr string
wantCode int
}{
{"unparsable address denies", "not-an-ip:1234", http.StatusForbidden},
{"empty address denies", "", http.StatusForbidden},
{"allowed address passes", "10.1.2.3:5678", http.StatusOK},
{"denied address blocked", "192.168.1.1:5678", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = tt.remoteAddr
req.Host = "example.com"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, tt.wantCode, rr.Code)
})
}
}
func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
// When CapturedData is set (by the access log middleware, which resolves
// trusted proxies), checkIPRestrictions should use that IP, not RemoteAddr.
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// RemoteAddr is a trusted proxy, but CapturedData has the real client IP.
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "10.0.0.1:5000"
req.Host = "example.com"
cd := proxy.NewCapturedData("")
cd.SetClientIP(netip.MustParseAddr("203.0.113.50"))
ctx := proxy.WithCapturedData(req.Context(), cd)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "should use CapturedData IP (203.0.113.50), not RemoteAddr (10.0.0.1)")
// Same request but CapturedData has a blocked IP.
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req2.RemoteAddr = "203.0.113.50:5000"
req2.Host = "example.com"
cd2 := proxy.NewCapturedData("")
cd2.SetClientIP(netip.MustParseAddr("10.0.0.1"))
ctx2 := proxy.WithCapturedData(req2.Context(), cd2)
req2 = req2.WithContext(ctx2)
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusForbidden, rr2.Code, "should use CapturedData IP (10.0.0.1), not RemoteAddr (203.0.113.50)")
}
func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
// Geo is nil, country restrictions are configured: must deny (fail-close).
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(nil, nil, []string{"US"}, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "1.2.3.4:5678"
req.Host = "example.com"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
}
// mockAuthenticator is a minimal mock for the authenticator gRPC interface
// used by the Header scheme.
type mockAuthenticator struct {
fn func(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error)
}
func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, _ ...grpc.CallOption) (*proto.AuthenticateResponse, error) {
return m.fn(ctx, in)
}
// newHeaderSchemeWithToken creates a Header scheme backed by a mock that
// returns a signed session token when the expected header value is provided.
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
t.Helper()
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
require.NoError(t, err)
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && ha.GetHeaderValue() == expectedValue {
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
return &proto.AuthenticateResponse{Success: false}, nil
}}
return NewHeader(mock, "svc1", "acc1", headerName)
}
func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
var backendCalled bool
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)
req.Header.Set("X-API-Key", "secret-key")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.True(t, backendCalled, "backend should be called directly for header auth (no redirect)")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "ok", rec.Body.String())
// Session cookie should be set.
var sessionCookie *http.Cookie
for _, c := range rec.Result().Cookies() {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
require.NotNil(t, sessionCookie, "session cookie should be set after successful header auth")
assert.True(t, sessionCookie.HttpOnly)
assert.True(t, sessionCookie.Secure)
assert.Equal(t, "header-user", capturedData.GetUserID())
assert.Equal(t, "header", capturedData.GetAuthMethod())
}
func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
// Also add a PIN scheme so we can verify fallthrough behavior.
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
// No X-API-Key header: should fall through to PIN login page (401).
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code, "missing header should fall through to login page")
}
func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
return &proto.AuthenticateResponse{Success: false}, nil
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("X-API-Key", "wrong-key")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Equal(t, "header", capturedData.GetAuthMethod())
}
func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
return nil, errors.New("gRPC unavailable")
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("X-API-Key", "some-key")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadGateway, rec.Code)
}
func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request with header auth.
req1 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req1.Header.Set("X-API-Key", "secret-key")
req1 = req1.WithContext(proxy.WithCapturedData(req1.Context(), proxy.NewCapturedData("")))
rec1 := httptest.NewRecorder()
handler.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
// Extract session cookie.
var sessionCookie *http.Cookie
for _, c := range rec1.Result().Cookies() {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
require.NotNil(t, sessionCookie)
// Second request with only the session cookie (no header).
capturedData2 := proxy.NewCapturedData("")
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/other", nil)
req2.AddCookie(sessionCookie)
req2 = req2.WithContext(proxy.WithCapturedData(req2.Context(), capturedData2))
rec2 := httptest.NewRecorder()
handler.ServeHTTP(rec2, req2)
assert.Equal(t, http.StatusOK, rec2.Code)
assert.Equal(t, "header-user", capturedData2.GetUserID())
assert.Equal(t, "header", capturedData2.GetAuthMethod())
}

View File

@@ -0,0 +1,264 @@
package geolocation
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
mmdbInnerName = "GeoLite2-City.mmdb"
downloadTimeout = 2 * time.Minute
maxMMDBSize = 256 << 20 // 256 MB
)
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
// it downloads from pkgs.netbird.io with SHA256 verification.
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
if err := os.MkdirAll(dataDir, 0o755); err != nil {
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
}
pattern := filepath.Join(dataDir, mmdbGlob)
if files, _ := filepath.Glob(pattern); len(files) > 0 {
mmdbPath := files[len(files)-1]
logger.Debugf("using existing geolocation database: %s", mmdbPath)
return mmdbPath, nil
}
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
return downloadMMDB(logger, dataDir)
}
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
client := &http.Client{Timeout: downloadTimeout}
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
if err != nil {
return "", fmt.Errorf("get remote filename: %w", err)
}
mmdbFilename := deriveMMDBFilename(datedName)
mmdbPath := filepath.Join(dataDir, mmdbFilename)
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
if err != nil {
return "", fmt.Errorf("create temp directory: %w", err)
}
defer os.RemoveAll(tmp)
checksumFile := filepath.Join(tmp, "checksum.sha256")
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
return "", fmt.Errorf("download checksum: %w", err)
}
expectedHash, err := readChecksumFile(checksumFile)
if err != nil {
return "", fmt.Errorf("read checksum: %w", err)
}
tarFile := filepath.Join(tmp, datedName)
logger.Debugf("downloading geolocation database (%s)", datedName)
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
return "", fmt.Errorf("download database: %w", err)
}
if err := verifySHA256(tarFile, expectedHash); err != nil {
return "", fmt.Errorf("verify database checksum: %w", err)
}
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
return "", fmt.Errorf("extract database: %w", err)
}
logger.Infof("geolocation database downloaded: %s", mmdbPath)
return mmdbPath, nil
}
// deriveMMDBFilename converts a tar.gz filename to an MMDB filename.
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
func deriveMMDBFilename(tarName string) string {
base, _, _ := strings.Cut(tarName, ".")
if !strings.Contains(base, "_") {
return "GeoLite2-City.mmdb"
}
return base + ".mmdb"
}
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
resp, err := client.Head(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
}
cd := resp.Header.Get("Content-Disposition")
if cd == "" {
return "", errors.New("no Content-Disposition header")
}
_, params, err := mime.ParseMediaType(cd)
if err != nil {
return "", fmt.Errorf("parse Content-Disposition: %w", err)
}
name := filepath.Base(params["filename"])
if name == "" || name == "." {
return "", errors.New("no filename in Content-Disposition")
}
return name, nil
}
func downloadToFile(client *http.Client, url, dest string) error {
resp, err := client.Get(url) //nolint:gosec
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
f, err := os.Create(dest) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
return err
}
return nil
}
func readChecksumFile(path string) (string, error) {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
if scanner.Scan() {
parts := strings.Fields(scanner.Text())
if len(parts) > 0 {
return parts[0], nil
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", errors.New("empty checksum file")
}
func verifySHA256(path, expected string) error {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actual := fmt.Sprintf("%x", h.Sum(nil))
if actual != expected {
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
}
return nil
}
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
f, err := os.Open(tarGZPath) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gz.Close()
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
}
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
return err
}
return nil
}
}
return fmt.Errorf("%s not found in archive", mmdbInnerName)
}
// extractToFileAtomic writes r to a temporary file in the same directory as
// destPath, then renames it into place so a crash never leaves a truncated file.
func extractToFileAtomic(r io.Reader, destPath string) error {
dir := filepath.Dir(destPath)
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmp.Name()
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
if closeErr := tmp.Close(); closeErr != nil {
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
}
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("write mmdb: %w", err)
}
if err := tmp.Close(); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("close temp file: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("rename to %s: %w", destPath, err)
}
return nil
}

View File

@@ -0,0 +1,152 @@
// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases.
package geolocation
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync"
"github.com/oschwald/maxminddb-golang"
log "github.com/sirupsen/logrus"
)
const (
// EnvDisable disables geolocation lookups entirely when set to a truthy value.
EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION"
mmdbGlob = "GeoLite2-City_*.mmdb"
)
type record struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
City struct {
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"city"`
Subdivisions []struct {
ISOCode string `maxminddb:"iso_code"`
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"subdivisions"`
}
// Result holds the outcome of a geo lookup.
type Result struct {
CountryCode string
CityName string
SubdivisionCode string
SubdivisionName string
}
// Lookup provides IP geolocation lookups.
type Lookup struct {
mu sync.RWMutex
db *maxminddb.Reader
logger *log.Logger
}
// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir.
// Returns nil without error if geolocation is disabled via environment
// variable, no data directory is configured, or the download fails
// (graceful degradation: country restrictions will deny all requests).
func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) {
if isDisabledByEnv(logger) {
logger.Info("geolocation disabled via environment variable")
return nil, nil //nolint:nilnil
}
if dataDir == "" {
return nil, nil //nolint:nilnil
}
mmdbPath, err := ensureMMDB(logger, dataDir)
if err != nil {
logger.Warnf("geolocation database unavailable: %v", err)
logger.Warn("country-based access restrictions will deny all requests until a database is available")
return nil, nil //nolint:nilnil
}
db, err := maxminddb.Open(mmdbPath)
if err != nil {
return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err)
}
logger.Infof("geolocation database loaded from %s", mmdbPath)
return &Lookup{db: db, logger: logger}, nil
}
// LookupAddr returns the country ISO code and city name for the given IP.
// Returns an empty Result if the database is nil or the lookup fails.
func (l *Lookup) LookupAddr(addr netip.Addr) Result {
if l == nil {
return Result{}
}
l.mu.RLock()
defer l.mu.RUnlock()
if l.db == nil {
return Result{}
}
addr = addr.Unmap()
var rec record
if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil {
l.logger.Debugf("geolocation lookup %s: %v", addr, err)
return Result{}
}
r := Result{
CountryCode: rec.Country.ISOCode,
CityName: rec.City.Names.En,
}
if len(rec.Subdivisions) > 0 {
r.SubdivisionCode = rec.Subdivisions[0].ISOCode
r.SubdivisionName = rec.Subdivisions[0].Names.En
}
return r
}
// Available reports whether the lookup has a loaded database.
func (l *Lookup) Available() bool {
if l == nil {
return false
}
l.mu.RLock()
defer l.mu.RUnlock()
return l.db != nil
}
// Close releases the database resources.
func (l *Lookup) Close() error {
if l == nil {
return nil
}
l.mu.Lock()
defer l.mu.Unlock()
if l.db != nil {
err := l.db.Close()
l.db = nil
return err
}
return nil
}
func isDisabledByEnv(logger *log.Logger) bool {
val := os.Getenv(EnvDisable)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
logger.Warnf("parse %s=%q: %v", EnvDisable, val, err)
return false
}
return disabled
}

View File

@@ -11,8 +11,6 @@ import (
type requestContextKey string
const (
serviceIdKey requestContextKey = "serviceId"
accountIdKey requestContextKey = "accountId"
capturedDataKey requestContextKey = "capturedData"
)
@@ -47,112 +45,117 @@ func (o ResponseOrigin) String() string {
// to pass data back up the middleware chain.
type CapturedData struct {
mu sync.RWMutex
RequestID string
ServiceId types.ServiceID
AccountId types.AccountID
Origin ResponseOrigin
ClientIP netip.Addr
UserID string
AuthMethod string
requestID string
serviceID types.ServiceID
accountID types.AccountID
origin ResponseOrigin
clientIP netip.Addr
userID string
authMethod string
}
// GetRequestID safely gets the request ID
// NewCapturedData creates a CapturedData with the given request ID.
func NewCapturedData(requestID string) *CapturedData {
return &CapturedData{requestID: requestID}
}
// GetRequestID returns the request ID.
func (c *CapturedData) GetRequestID() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.RequestID
return c.requestID
}
// SetServiceId safely sets the service ID
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) {
// SetServiceID sets the service ID.
func (c *CapturedData) SetServiceID(serviceID types.ServiceID) {
c.mu.Lock()
defer c.mu.Unlock()
c.ServiceId = serviceId
c.serviceID = serviceID
}
// GetServiceId safely gets the service ID
func (c *CapturedData) GetServiceId() types.ServiceID {
// GetServiceID returns the service ID.
func (c *CapturedData) GetServiceID() types.ServiceID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ServiceId
return c.serviceID
}
// SetAccountId safely sets the account ID
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
// SetAccountID sets the account ID.
func (c *CapturedData) SetAccountID(accountID types.AccountID) {
c.mu.Lock()
defer c.mu.Unlock()
c.AccountId = accountId
c.accountID = accountID
}
// GetAccountId safely gets the account ID
func (c *CapturedData) GetAccountId() types.AccountID {
// GetAccountID returns the account ID.
func (c *CapturedData) GetAccountID() types.AccountID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AccountId
return c.accountID
}
// SetOrigin safely sets the response origin
// SetOrigin sets the response origin.
func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
c.mu.Lock()
defer c.mu.Unlock()
c.Origin = origin
c.origin = origin
}
// GetOrigin safely gets the response origin
// GetOrigin returns the response origin.
func (c *CapturedData) GetOrigin() ResponseOrigin {
c.mu.RLock()
defer c.mu.RUnlock()
return c.Origin
return c.origin
}
// SetClientIP safely sets the resolved client IP.
// SetClientIP sets the resolved client IP.
func (c *CapturedData) SetClientIP(ip netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
c.ClientIP = ip
c.clientIP = ip
}
// GetClientIP safely gets the resolved client IP.
// GetClientIP returns the resolved client IP.
func (c *CapturedData) GetClientIP() netip.Addr {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ClientIP
return c.clientIP
}
// SetUserID safely sets the authenticated user ID.
// SetUserID sets the authenticated user ID.
func (c *CapturedData) SetUserID(userID string) {
c.mu.Lock()
defer c.mu.Unlock()
c.UserID = userID
c.userID = userID
}
// GetUserID safely gets the authenticated user ID.
// GetUserID returns the authenticated user ID.
func (c *CapturedData) GetUserID() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.UserID
return c.userID
}
// SetAuthMethod safely sets the authentication method used.
// SetAuthMethod sets the authentication method used.
func (c *CapturedData) SetAuthMethod(method string) {
c.mu.Lock()
defer c.mu.Unlock()
c.AuthMethod = method
c.authMethod = method
}
// GetAuthMethod safely gets the authentication method used.
// GetAuthMethod returns the authentication method used.
func (c *CapturedData) GetAuthMethod() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AuthMethod
return c.authMethod
}
// WithCapturedData adds a CapturedData struct to the context
// WithCapturedData adds a CapturedData struct to the context.
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
return context.WithValue(ctx, capturedDataKey, data)
}
// CapturedDataFromContext retrieves the CapturedData from context
// CapturedDataFromContext retrieves the CapturedData from context.
func CapturedDataFromContext(ctx context.Context) *CapturedData {
v := ctx.Value(capturedDataKey)
data, ok := v.(*CapturedData)
@@ -161,28 +164,3 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
}
return data
}
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
return context.WithValue(ctx, serviceIdKey, serviceId)
}
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
v := ctx.Value(serviceIdKey)
serviceId, ok := v.(types.ServiceID)
if !ok {
return ""
}
return serviceId
}
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
return context.WithValue(ctx, accountIdKey, accountId)
}
func AccountIdFromContext(ctx context.Context) types.AccountID {
v := ctx.Value(accountIdKey)
accountId, ok := v.(types.AccountID)
if !ok {
return ""
}
return accountId
}

View File

@@ -66,19 +66,16 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Set the serviceId in the context for later retrieval.
ctx := withServiceId(r.Context(), result.serviceID)
// Set the accountId in the context for later retrieval (for middleware).
ctx = withAccountId(ctx, result.accountID)
// Set the accountId in the context for the roundtripper to use.
ctx := r.Context()
// Set the account ID in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, result.accountID)
// Also populate captured data if it exists (allows middleware to read after handler completes).
// Populate captured data if it exists (allows middleware to read after handler completes).
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
// pointer in the context, and mutate the struct here so outer middleware can read it.
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
capturedData.SetServiceId(result.serviceID)
capturedData.SetAccountId(result.accountID)
capturedData.SetServiceID(result.serviceID)
capturedData.SetAccountID(result.accountID)
}
pt := result.target
@@ -96,10 +93,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders),
Transport: p.transport,
FlushInterval: -1,
ErrorHandler: proxyErrorHandler,
ErrorHandler: p.proxyErrorHandler,
}
if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
@@ -113,7 +110,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address.
// The pathRewrite parameter controls how the request path is transformed.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string, stripAuthHeaders []string) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) {
switch pathRewrite {
case PathRewritePreserve:
@@ -137,6 +134,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Host = target.Host
}
for _, h := range stripAuthHeaders {
r.Out.Header.Del(h)
}
for k, v := range customHeaders {
r.Out.Header.Set(k, v)
}
@@ -305,7 +306,7 @@ func extractForwardedPort(host, resolvedProto string) string {
// proxyErrorHandler handles errors from the reverse proxy and serves
// user-friendly error pages instead of raw error responses.
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
func (p *ReverseProxy) proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
if cd := CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(OriginProxyError)
}
@@ -313,7 +314,7 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
clientIP := getClientIP(r)
title, message, code, status := classifyProxyError(err)
log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
p.logger.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
web.ServeErrorPage(w, r, code, title, message, requestID, status)

View File

@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
})
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app")
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr)
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
rewrite(pr)
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https")
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443")
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -551,7 +551,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
t.Run("preserve keeps full request path", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil, nil)
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
rewrite(pr)
@@ -561,7 +561,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
})
t.Run("preserve with root matchedPath", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil, nil)
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
rewrite(pr)
@@ -579,7 +579,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
"X-Custom-Auth": "token-abc",
"X-Env": "production",
}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -589,7 +589,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
})
t.Run("nil customHeaders is fine", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -599,7 +599,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
t.Run("custom headers override existing request headers", func(t *testing.T) {
headers := map[string]string{"X-Override": "new-value"}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("X-Override", "old-value")
@@ -609,11 +609,38 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
})
}
func TestRewriteFunc_StripsAuthorizationHeader(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
t.Run("strips incoming Authorization when no custom Authorization set", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, []string{"Authorization"})
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("Authorization", "Bearer proxy-token")
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get("Authorization"), "Authorization should be stripped")
})
t.Run("custom Authorization replaces incoming", func(t *testing.T) {
headers := map[string]string{"Authorization": "Basic YmFja2VuZDpzZWNyZXQ="}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, []string{"Authorization"})
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("Authorization", "Bearer proxy-token")
rewrite(pr)
assert.Equal(t, "Basic YmFja2VuZDpzZWNyZXQ=", pr.Out.Header.Get("Authorization"),
"backend Authorization from custom headers should be set")
})
}
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}, nil)
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
rewrite(pr)

View File

@@ -38,6 +38,11 @@ type Mapping struct {
Paths map[string]*PathTarget
PassHostHeader bool
RewriteRedirects bool
// StripAuthHeaders are header names used for header-based auth.
// These headers are stripped from requests before forwarding.
StripAuthHeaders []string
// sortedPaths caches the paths sorted by length (longest first).
sortedPaths []string
}
type targetResult struct {
@@ -47,6 +52,7 @@ type targetResult struct {
accountID types.AccountID
passHostHeader bool
rewriteRedirects bool
stripAuthHeaders []string
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
@@ -65,16 +71,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
return targetResult{}, false
}
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
for _, path := range paths {
for _, path := range m.sortedPaths {
if strings.HasPrefix(req.URL.Path, path) {
pt := m.Paths[path]
if pt == nil || pt.URL == nil {
@@ -89,6 +86,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
accountID: m.AccountID,
passHostHeader: m.PassHostHeader,
rewriteRedirects: m.RewriteRedirects,
stripAuthHeaders: m.StripAuthHeaders,
}, true
}
}
@@ -96,7 +94,18 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
return targetResult{}, false
}
// AddMapping registers a host-to-backend mapping for the reverse proxy.
func (p *ReverseProxy) AddMapping(m Mapping) {
// Sort paths longest-first to match the most specific route first.
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
m.sortedPaths = paths
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
p.mappings[m.Host] = m

View File

@@ -0,0 +1,183 @@
// Package restrict provides connection-level access control based on
// IP CIDR ranges and geolocation (country codes).
package restrict
import (
"net/netip"
"slices"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
)
// GeoResolver resolves an IP address to geographic information.
type GeoResolver interface {
LookupAddr(addr netip.Addr) geolocation.Result
Available() bool
}
// Filter evaluates IP restrictions. CIDR checks are performed first
// (cheap), followed by country lookups (more expensive) only when needed.
type Filter struct {
AllowedCIDRs []netip.Prefix
BlockedCIDRs []netip.Prefix
AllowedCountries []string
BlockedCountries []string
}
// ParseFilter builds a Filter from the raw string slices. Returns nil
// if all slices are empty.
func ParseFilter(allowedCIDRs, blockedCIDRs, allowedCountries, blockedCountries []string) *Filter {
if len(allowedCIDRs) == 0 && len(blockedCIDRs) == 0 &&
len(allowedCountries) == 0 && len(blockedCountries) == 0 {
return nil
}
f := &Filter{
AllowedCountries: normalizeCountryCodes(allowedCountries),
BlockedCountries: normalizeCountryCodes(blockedCountries),
}
for _, cidr := range allowedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
log.Warnf("skip invalid allowed CIDR %q: %v", cidr, err)
continue
}
f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked())
}
for _, cidr := range blockedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
log.Warnf("skip invalid blocked CIDR %q: %v", cidr, err)
continue
}
f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked())
}
return f
}
func normalizeCountryCodes(codes []string) []string {
if len(codes) == 0 {
return nil
}
out := make([]string, len(codes))
for i, c := range codes {
out[i] = strings.ToUpper(c)
}
return out
}
// Verdict is the result of an access check.
type Verdict int
const (
// Allow indicates the address passed all checks.
Allow Verdict = iota
// DenyCIDR indicates the address was blocked by a CIDR rule.
DenyCIDR
// DenyCountry indicates the address was blocked by a country rule.
DenyCountry
// DenyGeoUnavailable indicates that country restrictions are configured
// but the geo lookup is unavailable.
DenyGeoUnavailable
)
// String returns the deny reason string matching the HTTP auth mechanism names.
func (v Verdict) String() string {
switch v {
case Allow:
return "allow"
case DenyCIDR:
return "ip_restricted"
case DenyCountry:
return "country_restricted"
case DenyGeoUnavailable:
return "geo_unavailable"
default:
return "unknown"
}
}
// Check evaluates whether addr is permitted. CIDR rules are evaluated
// first because they are O(n) prefix comparisons. Country rules run
// only when CIDR checks pass and require a geo lookup.
func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict {
if f == nil {
return Allow
}
// Normalize v4-mapped-v6 (e.g. ::ffff:10.1.2.3) to plain v4 so that
// IPv4 CIDR rules match regardless of how the address was received.
addr = addr.Unmap()
if v := f.checkCIDR(addr); v != Allow {
return v
}
return f.checkCountry(addr, geo)
}
func (f *Filter) checkCIDR(addr netip.Addr) Verdict {
if len(f.AllowedCIDRs) > 0 {
allowed := false
for _, prefix := range f.AllowedCIDRs {
if prefix.Contains(addr) {
allowed = true
break
}
}
if !allowed {
return DenyCIDR
}
}
for _, prefix := range f.BlockedCIDRs {
if prefix.Contains(addr) {
return DenyCIDR
}
}
return Allow
}
func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict {
if len(f.AllowedCountries) == 0 && len(f.BlockedCountries) == 0 {
return Allow
}
if geo == nil || !geo.Available() {
return DenyGeoUnavailable
}
result := geo.LookupAddr(addr)
if result.CountryCode == "" {
// Unknown country: deny if an allowlist is active, allow otherwise.
// Blocklists are best-effort: unknown countries pass through since
// the default policy is allow.
if len(f.AllowedCountries) > 0 {
return DenyCountry
}
return Allow
}
if len(f.AllowedCountries) > 0 {
if !slices.Contains(f.AllowedCountries, result.CountryCode) {
return DenyCountry
}
}
if slices.Contains(f.BlockedCountries, result.CountryCode) {
return DenyCountry
}
return Allow
}
// HasRestrictions returns true if any restriction rules are configured.
func (f *Filter) HasRestrictions() bool {
if f == nil {
return false
}
return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 ||
len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0
}

View File

@@ -0,0 +1,278 @@
package restrict
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
)
type mockGeo struct {
countries map[string]string
}
func (m *mockGeo) LookupAddr(addr netip.Addr) geolocation.Result {
return geolocation.Result{CountryCode: m.countries[addr.String()]}
}
func (m *mockGeo) Available() bool { return true }
func newMockGeo(entries map[string]string) *mockGeo {
return &mockGeo{countries: entries}
}
func TestFilter_Check_NilFilter(t *testing.T) {
var f *Filter
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
}
func TestFilter_Check_AllowedCIDR(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_BlockedCIDR(t *testing.T) {
f := ParseFilter(nil, []string{"10.0.0.0/8"}, nil, nil)
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, []string{"10.1.0.0/16"}, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "not in allowlist")
}
func TestFilter_Check_AllowedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
f := ParseFilter(nil, nil, []string{"US", "DE"}, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
}
func TestFilter_Check_BlockedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "CN",
"2.2.2.2": "RU",
"3.3.3.3": "US",
})
f := ParseFilter(nil, nil, nil, []string{"CN", "RU"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "US not in blocklist")
}
func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
// Allow US and DE, but block DE explicitly.
f := ParseFilter(nil, nil, []string{"US", "DE"}, []string{"DE"})
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
}
func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
})
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active")
}
func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "CN",
})
f := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active")
}
func TestFilter_Check_CountryWithoutGeo(t *testing.T) {
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist")
}
func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) {
f := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist")
}
func TestFilter_Check_GeoUnavailable(t *testing.T) {
geo := &unavailableGeo{}
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist")
f2 := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist")
}
func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
// CIDR-only filter should never touch geo, so nil geo is fine.
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) {
geo := newMockGeo(map[string]string{
"10.1.2.3": "CN",
"10.2.3.4": "US",
})
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, []string{"CN"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), geo), "CIDR denied before country check")
}
func TestParseFilter_Empty(t *testing.T) {
f := ParseFilter(nil, nil, nil, nil)
assert.Nil(t, f)
}
func TestParseFilter_InvalidCIDR(t *testing.T) {
f := ParseFilter([]string{"invalid", "10.0.0.0/8"}, nil, nil, nil)
assert.NotNil(t, f)
assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped")
assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), f.AllowedCIDRs[0])
}
func TestFilter_HasRestrictions(t *testing.T) {
assert.False(t, (*Filter)(nil).HasRestrictions())
assert.False(t, (&Filter{}).HasRestrictions())
assert.True(t, ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil).HasRestrictions())
assert.True(t, ParseFilter(nil, nil, []string{"US"}, nil).HasRestrictions())
}
func TestFilter_Check_IPv6CIDR(t *testing.T) {
f := ParseFilter([]string{"2001:db8::/32"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 addr not in v6 allowlist")
}
func TestFilter_Check_IPv4MappedIPv6(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
// A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR.
v4mapped := netip.MustParseAddr("::ffff:10.1.2.3")
assert.True(t, v4mapped.Is4In6(), "precondition: address is v4-in-v6")
assert.Equal(t, Allow, f.Check(v4mapped, nil), "v4-mapped-v6 must match v4 CIDR after Unmap")
v4mappedOutside := netip.MustParseAddr("::ffff:192.168.1.1")
assert.Equal(t, DenyCIDR, f.Check(v4mappedOutside, nil), "v4-mapped-v6 outside v4 CIDR")
}
func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8", "2001:db8::/32"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("::ffff:10.1.2.3"), nil), "v4-mapped matches v4 CIDR")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "v4 not in either CIDR")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("fe80::1"), nil), "v6 not in either CIDR")
}
func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) {
// 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24.
f := ParseFilter([]string{"1.1.1.1/24"}, nil, nil, nil)
assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0])
// Verify it still matches correctly.
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.100"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("1.1.2.1"), nil))
}
func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
tests := []struct {
name string
allowedCountries []string
blockedCountries []string
addr string
want Verdict
}{
{
name: "lowercase allowlist matches uppercase MaxMind code",
allowedCountries: []string{"us", "de"},
addr: "1.1.1.1",
want: Allow,
},
{
name: "mixed-case allowlist matches",
allowedCountries: []string{"Us", "dE"},
addr: "2.2.2.2",
want: Allow,
},
{
name: "lowercase allowlist rejects non-matching country",
allowedCountries: []string{"us", "de"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "lowercase blocklist blocks matching country",
blockedCountries: []string{"cn"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "mixed-case blocklist blocks matching country",
blockedCountries: []string{"Cn"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "lowercase blocklist does not block non-matching country",
blockedCountries: []string{"cn"},
addr: "1.1.1.1",
want: Allow,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
f := ParseFilter(nil, nil, tc.allowedCountries, tc.blockedCountries)
got := f.Check(netip.MustParseAddr(tc.addr), geo)
assert.Equal(t, tc.want, got)
})
}
}
// unavailableGeo simulates a GeoResolver whose database is not loaded.
type unavailableGeo struct{}
func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} }
func (u *unavailableGeo) Available() bool { return false }

View File

@@ -7,12 +7,14 @@ import (
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
@@ -20,6 +22,10 @@ import (
// timeout is configured.
const defaultDialTimeout = 30 * time.Second
// errAccessRestricted is returned by relayTCP for access restriction
// denials so callers can skip warn-level logging (already logged at debug).
var errAccessRestricted = errors.New("rejected by access restrictions")
// SNIHost is a typed key for SNI hostname lookups.
type SNIHost string
@@ -64,6 +70,11 @@ type Route struct {
// DialTimeout overrides the default dial timeout for this route.
// Zero uses defaultDialTimeout.
DialTimeout time.Duration
// SessionIdleTimeout overrides the default idle timeout for relay connections.
// Zero uses DefaultIdleTimeout.
SessionIdleTimeout time.Duration
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
Filter *restrict.Filter
}
// l4Logger sends layer-4 access log entries to the management server.
@@ -99,6 +110,7 @@ type Router struct {
drainDone chan struct{}
observer RelayObserver
accessLog l4Logger
geo restrict.GeoResolver
// svcCtxs tracks a context per service ID. All relay goroutines for a
// service derive from its context; canceling it kills them immediately.
svcCtxs map[types.ServiceID]context.Context
@@ -144,6 +156,7 @@ func (r *Router) HTTPListener() net.Listener {
// stored and resolved by priority at lookup time (HTTP > TCP).
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
func (r *Router) AddRoute(host SNIHost, route Route) {
host = SNIHost(strings.ToLower(string(host)))
if host == "" {
return
}
@@ -166,6 +179,8 @@ func (r *Router) AddRoute(host SNIHost, route Route) {
// Active relay connections for the service are closed immediately.
// If other routes remain for the host, they are preserved.
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
host = SNIHost(strings.ToLower(string(host)))
r.mu.Lock()
defer r.mu.Unlock()
@@ -295,7 +310,7 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
return
}
host := SNIHost(sni)
host := SNIHost(strings.ToLower(sni))
route, ok := r.lookupRoute(host)
if !ok {
r.handleUnmatched(ctx, wrapped)
@@ -308,11 +323,13 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
}
if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
r.logger.WithFields(log.Fields{
"sni": host,
"service_id": route.ServiceID,
"target": route.Target,
}).Warnf("TCP relay: %v", err)
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
"sni": host,
"service_id": route.ServiceID,
"target": route.Target,
}).Warnf("TCP relay: %v", err)
}
_ = wrapped.Close()
}
}
@@ -336,10 +353,12 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
if fb != nil {
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
r.logger.WithFields(log.Fields{
"service_id": fb.ServiceID,
"target": fb.Target,
}).Warnf("TCP relay (fallback): %v", err)
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
"service_id": fb.ServiceID,
"target": fb.Target,
}).Warnf("TCP relay (fallback): %v", err)
}
_ = conn.Close()
}
return
@@ -427,10 +446,44 @@ func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
}
}
// SetGeo sets the geolocation lookup used for country-based restrictions.
func (r *Router) SetGeo(geo restrict.GeoResolver) {
r.mu.Lock()
defer r.mu.Unlock()
r.geo = geo
}
// checkRestrictions evaluates the route's access filter against the
// connection's remote address. Returns Allow if the connection is
// permitted, or a deny verdict indicating the reason.
func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict {
if route.Filter == nil {
return restrict.Allow
}
addr, err := addrFromConn(conn)
if err != nil {
r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr())
return restrict.DenyCIDR
}
r.mu.RLock()
geo := r.geo
r.mu.RUnlock()
return route.Filter.Check(addr, geo)
}
// relayTCP sets up and runs a bidirectional TCP relay.
// The caller owns conn and must close it if this method returns an error.
// On success (nil error), both conn and backend are closed by the relay.
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
r.logL4Deny(route, conn, verdict)
return errAccessRestricted
}
svcCtx, err := r.acquireRelay(ctx, route)
if err != nil {
return err
@@ -468,8 +521,13 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
})
entry.Debug("TCP relay started")
idleTimeout := route.SessionIdleTimeout
if idleTimeout <= 0 {
idleTimeout = DefaultIdleTimeout
}
start := time.Now()
s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
elapsed := time.Since(start)
if obs != nil {
@@ -537,12 +595,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
return
}
var sourceIP netip.Addr
if remote := conn.RemoteAddr(); remote != nil {
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
sourceIP = ap.Addr().Unmap()
}
}
sourceIP, _ := addrFromConn(conn)
al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID,
@@ -556,6 +609,28 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
})
}
// logL4Deny sends an access log entry for a denied connection.
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) {
r.mu.RLock()
al := r.accessLog
r.mu.RUnlock()
if al == nil {
return
}
sourceIP, _ := addrFromConn(conn)
al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID,
ServiceID: route.ServiceID,
Protocol: route.Protocol,
Host: route.Domain,
SourceIP: sourceIP,
DenyReason: verdict.String(),
})
}
// getOrCreateServiceCtxLocked returns the context for a service, creating one
// if it doesn't exist yet. The context is a child of the server context.
// Must be called with mu held.
@@ -568,3 +643,16 @@ func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types
r.svcCancels[svcID] = cancel
return ctx
}
// addrFromConn extracts a netip.Addr from a connection's remote address.
func addrFromConn(conn net.Conn) (netip.Addr, error) {
remote := conn.RemoteAddr()
if remote == nil {
return netip.Addr{}, errors.New("no remote address")
}
ap, err := netip.ParseAddrPort(remote.String())
if err != nil {
return netip.Addr{}, err
}
return ap.Addr().Unmap(), nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
@@ -1668,3 +1669,73 @@ func startEchoPlain(t *testing.T) net.Listener {
return ln
}
// fakeAddr implements net.Addr with a custom string representation.
type fakeAddr string
func (f fakeAddr) Network() string { return "tcp" }
func (f fakeAddr) String() string { return string(f) }
// fakeConn is a minimal net.Conn with a controllable RemoteAddr.
type fakeConn struct {
net.Conn
remote net.Addr
}
func (f *fakeConn) RemoteAddr() net.Addr { return f.remote }
func TestCheckRestrictions_UnparseableAddress(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "unparsable address must be denied")
}
func TestCheckRestrictions_NilRemoteAddr(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: nil}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "nil remote address must be denied")
}
func TestCheckRestrictions_AllowedAndDenied(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(allowed, route), "10.1.2.3 in allowlist")
denied := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(denied, route), "192.168.1.1 not in allowlist")
}
func TestCheckRestrictions_NilFilter(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
route := Route{Filter: nil}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "nil filter should allow everything")
}
func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
// net.IPv4() returns a 16-byte v4-in-v6 representation internally.
// The restriction check must Unmap it to match the v4 CIDR.
conn := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 5678}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "v4-in-v6 TCPAddr must match v4 CIDR")
// Explicitly v4-mapped-v6 address string.
conn6 := &fakeConn{remote: fakeAddr("[::ffff:10.1.2.3]:5678")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn6, route), "::ffff:10.1.2.3 must match v4 CIDR")
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
@@ -67,6 +68,8 @@ type Relay struct {
dialTimeout time.Duration
sessionTTL time.Duration
maxSessions int
filter *restrict.Filter
geo restrict.GeoResolver
mu sync.RWMutex
sessions map[clientAddr]*session
@@ -114,6 +117,10 @@ type RelayConfig struct {
SessionTTL time.Duration
MaxSessions int
AccessLog l4Logger
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
Filter *restrict.Filter
// Geo is the geolocation lookup used for country-based restrictions.
Geo restrict.GeoResolver
}
// New creates a UDP relay for the given listener and backend target.
@@ -146,6 +153,8 @@ func New(parentCtx context.Context, cfg RelayConfig) *Relay {
dialTimeout: dialTimeout,
sessionTTL: sessionTTL,
maxSessions: maxSessions,
filter: cfg.Filter,
geo: cfg.Geo,
sessions: make(map[clientAddr]*session),
bufPool: sync.Pool{
New: func() any {
@@ -166,9 +175,18 @@ func (r *Relay) ServiceID() types.ServiceID {
// SetObserver sets the session lifecycle observer. Must be called before Serve.
func (r *Relay) SetObserver(obs SessionObserver) {
r.mu.Lock()
defer r.mu.Unlock()
r.observer = obs
}
// getObserver returns the current session lifecycle observer.
func (r *Relay) getObserver() SessionObserver {
r.mu.RLock()
defer r.mu.RUnlock()
return r.observer
}
// Serve starts the relay loop. It blocks until the context is canceled
// or the listener is closed.
func (r *Relay) Serve() {
@@ -209,8 +227,8 @@ func (r *Relay) Serve() {
}
sess.bytesIn.Add(int64(nw))
if r.observer != nil {
r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
if obs := r.getObserver(); obs != nil {
obs.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
}
r.bufPool.Put(bufp)
}
@@ -234,6 +252,10 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
return nil, r.ctx.Err()
}
if err := r.checkAccessRestrictions(addr); err != nil {
return nil, err
}
r.mu.Lock()
if sess, ok = r.sessions[key]; ok && sess != nil {
@@ -248,16 +270,16 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
if len(r.sessions) >= r.maxSessions {
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionRejected(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionRejected(r.accountID)
}
return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
}
if !r.sessLimiter.Allow() {
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionRejected(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionRejected(r.accountID)
}
return nil, fmt.Errorf("session creation rate limited")
}
@@ -274,8 +296,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
r.mu.Lock()
delete(r.sessions, key)
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionDialError(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionDialError(r.accountID)
}
return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
}
@@ -293,8 +315,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
r.sessions[key] = sess
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionStarted(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionStarted(r.accountID)
}
r.sessWg.Go(func() {
@@ -305,6 +327,21 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
return sess, nil
}
func (r *Relay) checkAccessRestrictions(addr net.Addr) error {
if r.filter == nil {
return nil
}
clientIP, err := addrFromUDPAddr(addr)
if err != nil {
return fmt.Errorf("parse client address %s for restriction check: %w", addr, err)
}
if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow {
r.logDeny(clientIP, v)
return fmt.Errorf("access restricted for %s", addr)
}
return nil
}
// relayBackendToClient reads packets from the backend and writes them
// back to the client through the public-facing listener.
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
@@ -332,8 +369,8 @@ func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
}
sess.bytesOut.Add(int64(nw))
if r.observer != nil {
r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
if obs := r.getObserver(); obs != nil {
obs.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
}
}
}
@@ -402,9 +439,10 @@ func (r *Relay) cleanupIdleSessions() {
}
r.mu.Unlock()
obs := r.getObserver()
for _, sess := range expired {
if r.observer != nil {
r.observer.UDPSessionEnded(r.accountID)
if obs != nil {
obs.UDPSessionEnded(r.accountID)
}
r.logSessionEnd(sess)
}
@@ -429,8 +467,8 @@ func (r *Relay) removeSession(sess *session) {
if removed {
r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
if r.observer != nil {
r.observer.UDPSessionEnded(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionEnded(r.accountID)
}
r.logSessionEnd(sess)
}
@@ -459,6 +497,22 @@ func (r *Relay) logSessionEnd(sess *session) {
})
}
// logDeny sends an access log entry for a denied UDP packet.
func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) {
if r.accessLog == nil {
return
}
r.accessLog.LogL4(accesslog.L4Entry{
AccountID: r.accountID,
ServiceID: r.serviceID,
Protocol: accesslog.ProtocolUDP,
Host: r.domain,
SourceIP: clientIP,
DenyReason: verdict.String(),
})
}
// Close stops the relay, waits for all session goroutines to exit,
// and cleans up remaining sessions.
func (r *Relay) Close() {
@@ -485,12 +539,22 @@ func (r *Relay) Close() {
}
r.mu.Unlock()
obs := r.getObserver()
for _, sess := range closedSessions {
if r.observer != nil {
r.observer.UDPSessionEnded(r.accountID)
if obs != nil {
obs.UDPSessionEnded(r.accountID)
}
r.logSessionEnd(sess)
}
r.sessWg.Wait()
}
// addrFromUDPAddr extracts a netip.Addr from a net.Addr.
func addrFromUDPAddr(addr net.Addr) (netip.Addr, error) {
ap, err := netip.ParseAddrPort(addr.String())
if err != nil {
return netip.Addr{}, err
}
return ap.Addr().Unmap(), nil
}

View File

@@ -490,7 +490,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
logger := log.New()
logger.SetLevel(log.WarnLevel)
authMw := auth.NewMiddleware(logger, nil)
authMw := auth.NewMiddleware(logger, nil, nil)
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
clusterAddress := "test.proxy.io"
@@ -511,6 +511,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
0,
proxytypes.AccountID(mapping.GetAccountId()),
proxytypes.ServiceID(mapping.GetId()),
nil,
)
require.NoError(t, err)

View File

@@ -43,12 +43,14 @@ import (
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/conntrack"
"github.com/netbirdio/netbird/proxy/internal/debug"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s"
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -59,7 +61,6 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// portRouter bundles a per-port Router with its listener and cancel func.
type portRouter struct {
router *nbtcp.Router
@@ -95,6 +96,9 @@ type Server struct {
// so they can be closed during graceful shutdown, since http.Server.Shutdown
// does not handle them.
hijackTracker conntrack.HijackTracker
// geo resolves IP addresses to country/city for access restrictions and access logs.
geo restrict.GeoResolver
geoRaw *geolocation.Lookup
// routerReady is closed once mainRouter is fully initialized.
// The mapping worker waits on this before processing updates.
@@ -159,10 +163,38 @@ type Server struct {
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
// ports for TCP/UDP/TLS services.
SupportsCustomPorts bool
// DefaultDialTimeout is the default timeout for establishing backend
// connections when no per-service timeout is configured. Zero means
// each transport uses its own hardcoded default (typically 30s).
DefaultDialTimeout time.Duration
// MaxDialTimeout caps the per-service backend dial timeout.
// When the API sends a timeout, it is clamped to this value.
// When the API sends no timeout, this value is used as the default.
// Zero means no cap (the proxy honors whatever management sends).
MaxDialTimeout time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files for
// country-based access restrictions. Empty disables geo lookups.
GeoDataDir string
// MaxSessionIdleTimeout caps the per-service session idle timeout.
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
return s.MaxSessionIdleTimeout
}
return d
}
// clampDialTimeout returns d capped to MaxDialTimeout when configured.
// If d is zero, MaxDialTimeout is used as the default.
func (s *Server) clampDialTimeout(d time.Duration) time.Duration {
if s.MaxDialTimeout <= 0 {
return d
}
if d <= 0 || d > s.MaxDialTimeout {
return s.MaxDialTimeout
}
return d
}
// NotifyStatus sends a status update to management about tunnel connectivity.
@@ -226,7 +258,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
@@ -236,6 +267,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
// Create health checker before the mapping worker so it can track
// management connectivity from the first stream connection.
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
tlsConfig, err := s.configureTLS(ctx)
if err != nil {
return err
@@ -244,14 +281,33 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
}
var startupOK bool
defer func() {
if startupOK {
return
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", err)
}
}
}()
// Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
// Configure Access logs to management server.
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.startDebugEndpoint()
if err := s.startHealthServer(); err != nil {
@@ -294,6 +350,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
}
startupOK = true
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
@@ -691,6 +749,16 @@ func (s *Server) shutdownServices() {
s.portRouterWg.Wait()
wg.Wait()
if s.accessLog != nil {
s.accessLog.Close()
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation: %v", err)
}
}
}
// resolveDialFunc returns a DialContextFunc that dials through the
@@ -1073,15 +1141,20 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TCP port %d: %w", port, err)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.SetFallback(nbtcp.Route{
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTCP,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTCP,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
})
s.portMu.Lock()
@@ -1108,6 +1181,8 @@ func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("empty target address for UDP service %s", svcID)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil {
return fmt.Errorf("UDP relay for service %s: %w", svcID, err)
}
@@ -1141,15 +1216,20 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TLS port %d: %w", tlsPort, err)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTLS,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTLS,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
})
if tlsPort != s.mainPort {
@@ -1181,6 +1261,32 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser
}
}
// parseRestrictions converts a proto mapping's access restrictions into
// a restrict.Filter. Returns nil if the mapping has no restrictions.
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
r := mapping.GetAccessRestrictions()
if r == nil {
return nil
}
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
}
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
// but the proxy has no geolocation database loaded. All requests to this
// service will be denied at runtime (fail-close).
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) {
if r == nil {
return
}
if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 {
return
}
if s.geo != nil && s.geo.Available() {
return
}
s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain)
}
// l4TargetAddress extracts and validates the target address from a mapping's
// first path entry. Returns empty string if no paths exist or the address is
// not a valid host:port.
@@ -1210,15 +1316,15 @@ func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool {
}
// l4DialTimeout returns the dial timeout from the first target's options,
// falling back to the server's DefaultDialTimeout.
// clamped to MaxDialTimeout.
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration {
paths := mapping.GetPath()
if len(paths) > 0 {
if d := paths[0].GetOptions().GetRequestTimeout(); d != nil {
return d.AsDuration()
return s.clampDialTimeout(d.AsDuration())
}
}
return s.DefaultDialTimeout
return s.clampDialTimeout(0)
}
// l4SessionIdleTimeout returns the configured session idle timeout from the
@@ -1254,7 +1360,9 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
dialFn, err := s.resolveDialFunc(accountID)
if err != nil {
_ = listener.Close()
if err := listener.Close(); err != nil {
s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err)
}
return fmt.Errorf("resolve dialer for UDP: %w", err)
}
@@ -1273,8 +1381,10 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
ServiceID: svcID,
DialFunc: dialFn,
DialTimeout: s.l4DialTimeout(mapping),
SessionTTL: l4SessionIdleTimeout(mapping),
SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
AccessLog: s.accessLog,
Filter: parseRestrictions(mapping),
Geo: s.geo,
})
relay.SetObserver(s.meter)
@@ -1306,9 +1416,15 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
if mapping.GetAuth().GetOidc() {
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto))
}
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
}
ipRestrictions := parseRestrictions(mapping)
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil {
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
}
m := s.protoToMapping(ctx, mapping)
@@ -1449,12 +1565,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
pt.RequestTimeout = d.AsDuration()
}
}
if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 {
pt.RequestTimeout = s.DefaultDialTimeout
}
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
paths[pathMapping.GetPath()] = pt
}
return proxy.Mapping{
m := proxy.Mapping{
ID: types.ServiceID(mapping.GetId()),
AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(),
@@ -1462,6 +1576,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
PassHostHeader: mapping.GetPassHostHeader(),
RewriteRedirects: mapping.GetRewriteRedirects(),
}
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader())
}
return m
}
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {

View File

@@ -2826,6 +2826,10 @@ components:
type: string
description: "City name from geolocation"
example: "San Francisco"
subdivision_code:
type: string
description: "First-level administrative subdivision ISO code (e.g. state/province)"
example: "CA"
bytes_upload:
type: integer
format: int64
@@ -2952,26 +2956,32 @@ components:
id:
type: string
description: Service ID
example: "cs8i4ug6lnn4g9hqv7mg"
name:
type: string
description: Service name
example: "myapp.example.netbird.app"
domain:
type: string
description: Domain for the service
example: "myapp.example.netbird.app"
mode:
type: string
description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough.
enum: [http, tcp, udp, tls]
default: http
example: "http"
listen_port:
type: integer
minimum: 0
maximum: 65535
description: Port the proxy listens on (L4/TLS only)
example: 8443
port_auto_assigned:
type: boolean
description: Whether the listen port was auto-assigned
readOnly: true
example: false
proxy_cluster:
type: string
description: The proxy cluster handling this service (derived from domain)
@@ -2984,14 +2994,19 @@ components:
enabled:
type: boolean
description: Whether the service is enabled
example: true
pass_host_header:
type: boolean
description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
example: false
rewrite_redirects:
type: boolean
description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
example: false
auth:
$ref: '#/components/schemas/ServiceAuthConfig'
access_restrictions:
$ref: '#/components/schemas/AccessRestrictions'
meta:
$ref: '#/components/schemas/ServiceMeta'
required:
@@ -3035,19 +3050,23 @@ components:
name:
type: string
description: Service name
example: "myapp.example.netbird.app"
domain:
type: string
description: Domain for the service
example: "myapp.example.netbird.app"
mode:
type: string
description: Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough.
enum: [http, tcp, udp, tls]
default: http
example: "http"
listen_port:
type: integer
minimum: 0
maximum: 65535
description: Port the proxy listens on (L4/TLS only). Set to 0 for auto-assignment.
example: 5432
targets:
type: array
items:
@@ -3057,14 +3076,19 @@ components:
type: boolean
description: Whether the service is enabled
default: true
example: true
pass_host_header:
type: boolean
description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
example: false
rewrite_redirects:
type: boolean
description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
example: false
auth:
$ref: '#/components/schemas/ServiceAuthConfig'
access_restrictions:
$ref: '#/components/schemas/AccessRestrictions'
required:
- name
- domain
@@ -3075,13 +3099,16 @@ components:
skip_tls_verify:
type: boolean
description: Skip TLS certificate verification for this backend
example: false
request_timeout:
type: string
description: Per-target response timeout as a Go duration string (e.g. "30s", "2m")
example: "30s"
path_rewrite:
type: string
description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
enum: [preserve]
example: "preserve"
custom_headers:
type: object
description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
@@ -3091,40 +3118,50 @@ components:
additionalProperties:
type: string
pattern: '^[^\r\n]*$'
example: {"X-Custom-Header": "value"}
proxy_protocol:
type: boolean
description: Send PROXY Protocol v2 header to this backend (TCP/TLS only)
example: false
session_idle_timeout:
type: string
description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m.
description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m").
example: "2m"
ServiceTarget:
type: object
properties:
target_id:
type: string
description: Target ID
example: "cs8i4ug6lnn4g9hqv7mg"
target_type:
type: string
description: Target type
enum: [peer, host, domain, subnet]
example: "subnet"
path:
type: string
description: URL path prefix for this target (HTTP only)
example: "/"
protocol:
type: string
description: Protocol to use when connecting to the backend
enum: [http, https, tcp, udp]
example: "http"
host:
type: string
description: Backend ip or domain for this target
example: "10.10.0.1"
port:
type: integer
minimum: 1
maximum: 65535
description: Backend port for this target
example: 8080
enabled:
type: boolean
description: Whether this target is enabled
example: true
options:
$ref: '#/components/schemas/ServiceTargetOptions'
required:
@@ -3144,15 +3181,73 @@ components:
$ref: '#/components/schemas/BearerAuthConfig'
link_auth:
$ref: '#/components/schemas/LinkAuthConfig'
header_auths:
type: array
items:
$ref: '#/components/schemas/HeaderAuthConfig'
HeaderAuthConfig:
type: object
description: Static header-value authentication. The proxy checks that the named header matches the configured value.
properties:
enabled:
type: boolean
description: Whether header auth is enabled
example: true
header:
type: string
description: HTTP header name to check (e.g. "Authorization", "X-API-Key")
example: "X-API-Key"
value:
type: string
description: Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses.
example: "my-secret-api-key"
required:
- enabled
- header
- value
AccessRestrictions:
type: object
description: Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
properties:
allowed_cidrs:
type: array
items:
type: string
format: cidr
example: "192.168.1.0/24"
description: CIDR allowlist. If non-empty, only IPs matching these CIDRs are allowed.
blocked_cidrs:
type: array
items:
type: string
format: cidr
example: "10.0.0.0/8"
description: CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs.
allowed_countries:
type: array
items:
type: string
pattern: '^[a-zA-Z]{2}$'
example: "US"
description: ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted.
blocked_countries:
type: array
items:
type: string
pattern: '^[a-zA-Z]{2}$'
example: "DE"
description: ISO 3166-1 alpha-2 country codes to block.
PasswordAuthConfig:
type: object
properties:
enabled:
type: boolean
description: Whether password auth is enabled
example: true
password:
type: string
description: Auth password
example: "s3cret"
required:
- enabled
- password
@@ -3162,9 +3257,11 @@ components:
enabled:
type: boolean
description: Whether PIN auth is enabled
example: false
pin:
type: string
description: PIN value
example: "1234"
required:
- enabled
- pin
@@ -3174,10 +3271,12 @@ components:
enabled:
type: boolean
description: Whether bearer auth is enabled
example: true
distribution_groups:
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv7mg"
description: List of group IDs that can use bearer auth
required:
- enabled
@@ -3187,6 +3286,7 @@ components:
enabled:
type: boolean
description: Whether link auth is enabled
example: false
required:
- enabled
ProxyCluster:
@@ -3217,20 +3317,25 @@ components:
id:
type: string
description: Domain ID
example: "ds8i4ug6lnn4g9hqv7mg"
domain:
type: string
description: Domain name
example: "example.netbird.app"
validated:
type: boolean
description: Whether the domain has been validated
example: true
type:
$ref: '#/components/schemas/ReverseProxyDomainType'
target_cluster:
type: string
description: The proxy cluster this domain is validated against (only for custom domains)
example: "eu.proxy.netbird.io"
supports_custom_ports:
type: boolean
description: Whether the cluster supports binding arbitrary TCP/UDP ports
example: true
required:
- id
- domain
@@ -3242,9 +3347,11 @@ components:
domain:
type: string
description: Domain name
example: "myapp.example.com"
target_cluster:
type: string
description: The proxy cluster this domain should be validated against
example: "eu.proxy.netbird.io"
required:
- domain
- target_cluster

View File

@@ -1276,6 +1276,21 @@ func (e PutApiIntegrationsMspTenantsIdInviteJSONBodyValue) Valid() bool {
}
}
// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
type AccessRestrictions struct {
// AllowedCidrs CIDR allowlist. If non-empty, only IPs matching these CIDRs are allowed.
AllowedCidrs *[]string `json:"allowed_cidrs,omitempty"`
// AllowedCountries ISO 3166-1 alpha-2 country codes to allow. If non-empty, only these countries are permitted.
AllowedCountries *[]string `json:"allowed_countries,omitempty"`
// BlockedCidrs CIDR blocklist. Connections from these CIDRs are rejected. Evaluated after allowed_cidrs.
BlockedCidrs *[]string `json:"blocked_cidrs,omitempty"`
// BlockedCountries ISO 3166-1 alpha-2 country codes to block.
BlockedCountries *[]string `json:"blocked_countries,omitempty"`
}
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
// CityName Commonly used English name of the city
@@ -1988,6 +2003,18 @@ type GroupRequest struct {
Resources *[]Resource `json:"resources,omitempty"`
}
// HeaderAuthConfig Static header-value authentication. The proxy checks that the named header matches the configured value.
type HeaderAuthConfig struct {
// Enabled Whether header auth is enabled
Enabled bool `json:"enabled"`
// Header HTTP header name to check (e.g. "Authorization", "X-API-Key")
Header string `json:"header"`
// Value Expected header value. For Basic auth use "Basic base64(user:pass)". For Bearer use "Bearer token". Cleared in responses.
Value string `json:"value"`
}
// HuntressMatchAttributes Attribute conditions to match when approving agents
type HuntressMatchAttributes struct {
// DefenderPolicyStatus Policy status of Defender AV for Managed Antivirus.
@@ -3324,6 +3351,9 @@ type ProxyAccessLog struct {
// StatusCode HTTP status code returned
StatusCode int `json:"status_code"`
// SubdivisionCode First-level administrative subdivision ISO code (e.g. state/province)
SubdivisionCode *string `json:"subdivision_code,omitempty"`
// Timestamp Timestamp when the request was made
Timestamp time.Time `json:"timestamp"`
@@ -3562,7 +3592,9 @@ type SentinelOneMatchAttributesNetworkStatus string
// Service defines model for Service.
type Service struct {
Auth ServiceAuthConfig `json:"auth"`
// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"`
Auth ServiceAuthConfig `json:"auth"`
// Domain Domain for the service
Domain string `json:"domain"`
@@ -3605,6 +3637,7 @@ type ServiceMode string
// ServiceAuthConfig defines model for ServiceAuthConfig.
type ServiceAuthConfig struct {
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty"`
HeaderAuths *[]HeaderAuthConfig `json:"header_auths,omitempty"`
LinkAuth *LinkAuthConfig `json:"link_auth,omitempty"`
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty"`
@@ -3627,7 +3660,9 @@ type ServiceMetaStatus string
// ServiceRequest defines model for ServiceRequest.
type ServiceRequest struct {
Auth *ServiceAuthConfig `json:"auth,omitempty"`
// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"`
Auth *ServiceAuthConfig `json:"auth,omitempty"`
// Domain Domain for the service
Domain string `json:"domain"`
@@ -3702,7 +3737,7 @@ type ServiceTargetOptions struct {
// RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m")
RequestTimeout *string `json:"request_timeout,omitempty"`
// SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). Maximum 10m.
// SessionIdleTimeout Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m").
SessionIdleTimeout *string `json:"session_idle_timeout,omitempty"`
// SkipTlsVerify Skip TLS certificate verification for this backend

File diff suppressed because it is too large Load Diff

View File

@@ -80,12 +80,27 @@ message PathMapping {
PathTargetOptions options = 3;
}
message HeaderAuth {
// Header name to check, e.g. "Authorization", "X-API-Key".
string header = 1;
// argon2id hash of the expected full header value.
string hashed_value = 2;
}
message Authentication {
string session_key = 1;
int64 max_session_age_seconds = 2;
bool password = 3;
bool pin = 4;
bool oidc = 5;
repeated HeaderAuth header_auths = 6;
}
message AccessRestrictions {
repeated string allowed_cidrs = 1;
repeated string blocked_cidrs = 2;
repeated string allowed_countries = 3;
repeated string blocked_countries = 4;
}
message ProxyMapping {
@@ -106,6 +121,7 @@ message ProxyMapping {
string mode = 10;
// For L4/TLS: the port the proxy listens on.
int32 listen_port = 11;
AccessRestrictions access_restrictions = 12;
}
// SendAccessLogRequest consists of one or more AccessLogs from a Proxy.
@@ -141,9 +157,15 @@ message AuthenticateRequest {
oneof request {
PasswordRequest password = 3;
PinRequest pin = 4;
HeaderAuthRequest header_auth = 5;
}
}
message HeaderAuthRequest {
string header_value = 1;
string header_name = 2;
}
message PasswordRequest {
string password = 1;
}

View File

@@ -65,8 +65,8 @@ func (b *earlyMsgBuffer) put(peerID messages.PeerID, msg Msg) bool {
}
entry := earlyMsg{
peerID: peerID,
msg: msg,
peerID: peerID,
msg: msg,
createdAt: time.Now(),
}
elem := b.order.PushBack(entry)