[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -10,7 +10,7 @@ FROM gcr.io/distroless/base:debug
COPY netbird-proxy /go/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
USER netbird:netbird
ENV HOME=/var/lib/netbird

View File

@@ -28,7 +28,7 @@ FROM gcr.io/distroless/base:debug
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
USER netbird:netbird
ENV HOME=/var/lib/netbird

View File

@@ -13,10 +13,11 @@ import (
type Method string
var (
const (
MethodPassword Method = "password"
MethodPIN Method = "pin"
MethodOIDC Method = "oidc"
MethodHeader Method = "header"
)
func (m Method) String() string {

View File

@@ -36,31 +36,33 @@ var (
var (
logLevel string
debugLogs bool
mgmtAddr string
addr string
proxyDomain string
defaultDialTimeout time.Duration
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeEABKID string
acmeEABHMACKey string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wildcardCertDir string
wgPort uint16
proxyProtocol bool
preSharedKey string
supportsCustomPorts bool
debugLogs bool
mgmtAddr string
addr string
proxyDomain string
maxDialTimeout time.Duration
maxSessionIdleTimeout time.Duration
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeEABKID string
acmeEABHMACKey string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wildcardCertDir string
wgPort uint16
proxyProtocol bool
preSharedKey string
supportsCustomPorts bool
geoDataDir string
)
var rootCmd = &cobra.Command{
@@ -99,7 +101,9 @@ func init() {
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)")
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
}
// Execute runs the root command.
@@ -177,17 +181,15 @@ func runServer(cmd *cobra.Command, args []string) error {
ProxyProtocol: proxyProtocol,
PreSharedKey: preSharedKey,
SupportsCustomPorts: supportsCustomPorts,
DefaultDialTimeout: defaultDialTimeout,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
GeoDataDir: geoDataDir,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()
if err := srv.ListenAndServe(ctx, addr); err != nil {
logger.Error(err)
return err
}
return nil
return srv.ListenAndServe(ctx, addr)
}
func envBoolOrDefault(key string, def bool) bool {
@@ -197,6 +199,7 @@ func envBoolOrDefault(key string, def bool) bool {
}
parsed, err := strconv.ParseBool(v)
if err != nil {
log.Warnf("parse %s=%q: %v, using default %v", key, v, err, def)
return def
}
return parsed
@@ -217,6 +220,7 @@ func envUint16OrDefault(key string, def uint16) uint16 {
}
parsed, err := strconv.ParseUint(v, 10, 16)
if err != nil {
log.Warnf("parse %s=%q: %v, using default %d", key, v, err, def)
return def
}
return uint16(parsed)
@@ -229,6 +233,7 @@ func envDurationOrDefault(key string, def time.Duration) time.Duration {
}
parsed, err := time.ParseDuration(v)
if err != nil {
log.Warnf("parse %s=%q: %v, using default %s", key, v, err, def)
return def
}
return parsed

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
@@ -22,6 +23,16 @@ const (
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
logSendTimeout = 10 * time.Second
// denyCooldown is the min interval between deny log entries per service+reason
// to prevent flooding from denied connections (e.g. UDP packets from blocked IPs).
denyCooldown = 10 * time.Second
// maxDenyBuckets caps tracked deny rate-limit entries to bound memory under DDoS.
maxDenyBuckets = 10000
// maxLogWorkers caps concurrent gRPC send goroutines.
maxLogWorkers = 4096
)
type domainUsage struct {
@@ -38,6 +49,18 @@ type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
}
// denyBucketKey identifies a rate-limited deny log stream.
type denyBucketKey struct {
ServiceID types.ServiceID
Reason string
}
// denyBucket tracks rate-limited deny log entries.
type denyBucket struct {
lastLogged time.Time
suppressed int64
}
// Logger sends access log entries to the management server via gRPC.
type Logger struct {
client gRPCClient
@@ -47,7 +70,12 @@ type Logger struct {
usageMux sync.Mutex
domainUsage map[string]*domainUsage
denyMu sync.Mutex
denyBuckets map[denyBucketKey]*denyBucket
logSem chan struct{}
cleanupCancel context.CancelFunc
dropped atomic.Int64
}
// NewLogger creates a new access log Logger. The trustedProxies parameter
@@ -64,6 +92,8 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre
logger: logger,
trustedProxies: trustedProxies,
domainUsage: make(map[string]*domainUsage),
denyBuckets: make(map[denyBucketKey]*denyBucket),
logSem: make(chan struct{}, maxLogWorkers),
cleanupCancel: cancel,
}
@@ -83,7 +113,7 @@ func (l *Logger) Close() {
type logEntry struct {
ID string
AccountID types.AccountID
ServiceId types.ServiceID
ServiceID types.ServiceID
Host string
Path string
DurationMs int64
@@ -91,7 +121,7 @@ type logEntry struct {
ResponseCode int32
SourceIP netip.Addr
AuthMechanism string
UserId string
UserID string
AuthSuccess bool
BytesUpload int64
BytesDownload int64
@@ -118,6 +148,10 @@ type L4Entry struct {
DurationMs int64
BytesUpload int64
BytesDownload int64
// DenyReason, when non-empty, indicates the connection was denied.
// Values match the HTTP auth mechanism strings: "ip_restricted",
// "country_restricted", "geo_unavailable".
DenyReason string
}
// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP).
@@ -126,7 +160,7 @@ func (l *Logger) LogL4(entry L4Entry) {
le := logEntry{
ID: xid.New().String(),
AccountID: entry.AccountID,
ServiceId: entry.ServiceID,
ServiceID: entry.ServiceID,
Protocol: entry.Protocol,
Host: entry.Host,
SourceIP: entry.SourceIP,
@@ -134,10 +168,47 @@ func (l *Logger) LogL4(entry L4Entry) {
BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload,
}
if entry.DenyReason != "" {
if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) {
return
}
le.AuthMechanism = entry.DenyReason
le.AuthSuccess = false
}
l.log(le)
l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload)
}
// allowDenyLog rate-limits deny log entries per service+reason combination.
func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool {
key := denyBucketKey{ServiceID: serviceID, Reason: reason}
now := time.Now()
l.denyMu.Lock()
defer l.denyMu.Unlock()
b, ok := l.denyBuckets[key]
if !ok {
if len(l.denyBuckets) >= maxDenyBuckets {
return false
}
l.denyBuckets[key] = &denyBucket{lastLogged: now}
return true
}
if now.Sub(b.lastLogged) >= denyCooldown {
if b.suppressed > 0 {
l.logger.Debugf("access restriction: suppressed %d deny log entries for %s (%s)", b.suppressed, serviceID, reason)
}
b.lastLogged = now
b.suppressed = 0
return true
}
b.suppressed++
return false
}
func (l *Logger) log(entry logEntry) {
// Fire off the log request in a separate routine.
// This increases the possibility of losing a log message
@@ -147,12 +218,21 @@ func (l *Logger) log(entry logEntry) {
// There is also a chance that log messages will arrive at
// the server out of order; however, the timestamp should
// allow for resolving that on the server.
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
now := timestamppb.Now()
select {
case l.logSem <- struct{}{}:
default:
total := l.dropped.Add(1)
l.logger.Debugf("access log send dropped: worker limit reached (total dropped: %d)", total)
return
}
go func() {
defer func() { <-l.logSem }()
logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout)
defer cancel()
// Only OIDC sessions have a meaningful user identity.
if entry.AuthMechanism != auth.MethodOIDC.String() {
entry.UserId = ""
entry.UserID = ""
}
var sourceIP string
@@ -165,7 +245,7 @@ func (l *Logger) log(entry logEntry) {
LogId: entry.ID,
AccountId: string(entry.AccountID),
Timestamp: now,
ServiceId: string(entry.ServiceId),
ServiceId: string(entry.ServiceID),
Host: entry.Host,
Path: entry.Path,
DurationMs: entry.DurationMs,
@@ -173,7 +253,7 @@ func (l *Logger) log(entry logEntry) {
ResponseCode: entry.ResponseCode,
SourceIp: sourceIP,
AuthMechanism: entry.AuthMechanism,
UserId: entry.UserId,
UserId: entry.UserID,
AuthSuccess: entry.AuthSuccess,
BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload,
@@ -181,7 +261,7 @@ func (l *Logger) log(entry logEntry) {
},
}); err != nil {
l.logger.WithFields(log.Fields{
"service_id": entry.ServiceId,
"service_id": entry.ServiceID,
"host": entry.Host,
"path": entry.Path,
"duration": entry.DurationMs,
@@ -189,7 +269,7 @@ func (l *Logger) log(entry logEntry) {
"response_code": entry.ResponseCode,
"source_ip": sourceIP,
"auth_mechanism": entry.AuthMechanism,
"user_id": entry.UserId,
"user_id": entry.UserID,
"auth_success": entry.AuthSuccess,
"error": err,
}).Error("Error sending access log on gRPC connection")
@@ -248,7 +328,7 @@ func (l *Logger) trackUsage(domain string, bytesTransferred int64) {
}
}
// cleanupStaleUsage removes usage entries for domains that have been inactive.
// cleanupStaleUsage removes usage and deny-rate-limit entries that have been inactive.
func (l *Logger) cleanupStaleUsage(ctx context.Context) {
ticker := time.NewTicker(usageCleanupPeriod)
defer ticker.Stop()
@@ -258,20 +338,41 @@ func (l *Logger) cleanupStaleUsage(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
l.usageMux.Lock()
now := time.Now()
removed := 0
for domain, usage := range l.domainUsage {
if now.Sub(usage.lastActivity) > usageInactiveWindow {
delete(l.domainUsage, domain)
removed++
}
}
l.usageMux.Unlock()
if removed > 0 {
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
}
l.cleanupDomainUsage(now)
l.cleanupDenyBuckets(now)
}
}
}
func (l *Logger) cleanupDomainUsage(now time.Time) {
l.usageMux.Lock()
defer l.usageMux.Unlock()
removed := 0
for domain, usage := range l.domainUsage {
if now.Sub(usage.lastActivity) > usageInactiveWindow {
delete(l.domainUsage, domain)
removed++
}
}
if removed > 0 {
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
}
}
func (l *Logger) cleanupDenyBuckets(now time.Time) {
l.denyMu.Lock()
defer l.denyMu.Unlock()
removed := 0
for key, bucket := range l.denyBuckets {
if now.Sub(bucket.lastLogged) > usageInactiveWindow {
delete(l.denyBuckets, key)
removed++
}
}
if removed > 0 {
l.logger.Debugf("cleaned up %d stale deny rate-limit entries", removed)
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/proxy/web"
)
// Middleware wraps an HTTP handler to log access entries and resolve client IPs.
func (l *Logger) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip logging for internal proxy assets (CSS, JS, etc.)
@@ -47,8 +48,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
// Create a mutable struct to capture data from downstream handlers.
// We pass a pointer in the context - the pointer itself flows down immutably,
// but the struct it points to can be mutated by inner handlers.
capturedData := &proxy.CapturedData{RequestID: requestID}
capturedData := proxy.NewCapturedData(requestID)
capturedData.SetClientIP(sourceIp)
ctx := proxy.WithCapturedData(r.Context(), capturedData)
start := time.Now()
@@ -66,8 +68,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
entry := logEntry{
ID: requestID,
ServiceId: capturedData.GetServiceId(),
AccountID: capturedData.GetAccountId(),
ServiceID: capturedData.GetServiceID(),
AccountID: capturedData.GetAccountID(),
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
@@ -75,14 +77,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
ResponseCode: int32(sw.status),
SourceIP: sourceIp,
AuthMechanism: capturedData.GetAuthMethod(),
UserId: capturedData.GetUserID(),
UserID: capturedData.GetUserID(),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
BytesUpload: bytesUpload,
BytesDownload: bytesDownload,
Protocol: ProtocolHTTP,
}
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID())
l.log(entry)

View File

@@ -0,0 +1,69 @@
package auth
import (
"errors"
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ErrHeaderAuthFailed indicates that the header was present but the
// credential did not validate. Callers should return 401 instead of
// falling through to other auth schemes.
var ErrHeaderAuthFailed = errors.New("header authentication failed")
// Header implements header-based authentication. The proxy checks for the
// configured header in each request and validates its value via gRPC.
type Header struct {
id types.ServiceID
accountId types.AccountID
headerName string
client authenticator
}
// NewHeader creates a Header authentication scheme for the given header name.
func NewHeader(client authenticator, id types.ServiceID, accountId types.AccountID, headerName string) Header {
return Header{
id: id,
accountId: accountId,
headerName: headerName,
client: client,
}
}
// Type returns auth.MethodHeader.
func (Header) Type() auth.Method {
return auth.MethodHeader
}
// Authenticate checks for the configured header in the request. If absent,
// returns empty (unauthenticated). If present, validates via gRPC.
func (h Header) Authenticate(r *http.Request) (string, string, error) {
value := r.Header.Get(h.headerName)
if value == "" {
return "", "", nil
}
res, err := h.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
Id: string(h.id),
AccountId: string(h.accountId),
Request: &proto.AuthenticateRequest_HeaderAuth{
HeaderAuth: &proto.HeaderAuthRequest{
HeaderValue: value,
HeaderName: h.headerName,
},
},
})
if err != nil {
return "", "", fmt.Errorf("authenticate header: %w", err)
}
if res.GetSuccess() {
return res.GetSessionToken(), "", nil
}
return "", "", ErrHeaderAuthFailed
}

View File

@@ -4,9 +4,12 @@ import (
"context"
"crypto/ed25519"
"encoding/base64"
"errors"
"fmt"
"html"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
@@ -16,11 +19,16 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
"github.com/netbirdio/netbird/shared/management/proto"
)
// errValidationUnavailable indicates that session validation failed due to
// an infrastructure error (e.g. gRPC unavailable), not an invalid token.
var errValidationUnavailable = errors.New("session validation unavailable")
type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
}
@@ -40,12 +48,14 @@ type Scheme interface {
Authenticate(*http.Request) (token string, promptData string, err error)
}
// DomainConfig holds the authentication and restriction settings for a protected domain.
type DomainConfig struct {
Schemes []Scheme
SessionPublicKey ed25519.PublicKey
SessionExpiration time.Duration
AccountID types.AccountID
ServiceID types.ServiceID
IPRestrictions *restrict.Filter
}
type validationResult struct {
@@ -54,17 +64,18 @@ type validationResult struct {
DeniedReason string
}
// Middleware applies per-domain authentication and IP restriction checks.
type Middleware struct {
domainsMux sync.RWMutex
domains map[string]DomainConfig
logger *log.Logger
sessionValidator SessionValidator
geo restrict.GeoResolver
}
// NewMiddleware creates a new authentication middleware.
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
// locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
// NewMiddleware creates a new authentication middleware. The sessionValidator is
// optional; if nil, OIDC session tokens are validated locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware {
if logger == nil {
logger = log.StandardLogger()
}
@@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl
domains: make(map[string]DomainConfig),
logger: logger,
sessionValidator: sessionValidator,
geo: geo,
}
}
// Protect applies authentication middleware to the passed handler.
// For each incoming request it will be checked against the middleware's
// internal list of protected domains.
// If the Host domain in the inbound request is not present, then it will
// simply be passed through.
// However, if the Host domain is present, then the specified authentication
// schemes for that domain will be applied to the request.
// In the event that no authentication schemes are defined for the domain,
// then the request will also be simply passed through.
// Protect wraps next with per-domain authentication and IP restriction checks.
// Requests whose Host is not registered pass through unchanged.
func (mw *Middleware) Protect(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
@@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
config, exists := mw.getDomainConfig(host)
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
if !exists || len(config.Schemes) == 0 {
if !exists {
next.ServeHTTP(w, r)
return
}
@@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
// Set account and service IDs in captured data for access logging.
setCapturedIDs(r, config)
if !mw.checkIPRestrictions(w, r, config) {
return
}
// Domains with no authentication schemes pass through after IP checks.
if len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
return
}
if mw.handleOAuthCallbackError(w, r) {
return
}
@@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
if mw.forwardWithHeaderAuth(w, r, host, config, next) {
return
}
mw.authenticateWithSchemes(w, r, host, config)
})
}
@@ -124,11 +142,65 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
func setCapturedIDs(r *http.Request, config DomainConfig) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(config.AccountID)
cd.SetServiceId(config.ServiceID)
cd.SetAccountID(config.AccountID)
cd.SetServiceID(config.ServiceID)
}
}
// checkIPRestrictions validates the client IP against the domain's IP restrictions.
// Uses the resolved client IP from CapturedData (which accounts for trusted proxies)
// rather than r.RemoteAddr directly.
func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
if config.IPRestrictions == nil {
return true
}
clientIP := mw.resolveClientIP(r)
if !clientIP.IsValid() {
mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
if verdict == restrict.Allow {
return true
}
reason := verdict.String()
mw.blockIPRestriction(r, reason)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr.
func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
if ip := cd.GetClientIP(); ip.IsValid() {
return ip
}
}
clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr)
if clientIPStr == "" {
clientIPStr = r.RemoteAddr
}
addr, err := netip.ParseAddr(clientIPStr)
if err != nil {
return netip.Addr{}
}
return addr.Unmap()
}
// blockIPRestriction sets captured data fields for an IP-restriction block event.
func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(reason)
}
mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr)
}
// handleOAuthCallbackError checks for error query parameters from an OAuth
// callback and renders the access denied page if present.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
@@ -146,6 +218,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re
errDesc := r.URL.Query().Get("error_description")
if errDesc == "" {
errDesc = "An error occurred during authentication"
} else {
errDesc = html.EscapeString(errDesc)
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
return true
@@ -170,6 +244,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
return true
}
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
// the request is forwarded directly (no redirect), which is important for API clients.
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
for _, scheme := range config.Schemes {
hdr, ok := scheme.(Header)
if !ok {
continue
}
handled := mw.tryHeaderScheme(w, r, host, config, hdr, next)
if handled {
return true
}
}
return false
}
func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool {
token, _, err := hdr.Authenticate(r)
if err != nil {
return mw.handleHeaderAuthError(w, r, err)
}
if token == "" {
return false
}
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
if err != nil {
setHeaderCapturedData(r.Context(), "")
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return true
}
if !result.Valid {
setHeaderCapturedData(r.Context(), result.UserID)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
setSessionCookie(w, token, config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(result.UserID)
cd.SetAuthMethod(auth.MethodHeader.String())
}
next.ServeHTTP(w, r)
return true
}
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
if errors.Is(err, ErrHeaderAuthFailed) {
setHeaderCapturedData(r.Context(), "")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return true
}
func setHeaderCapturedData(ctx context.Context, userID string) {
cd := proxy.CapturedDataFromContext(ctx)
if cd == nil {
return
}
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodHeader.String())
cd.SetUserID(userID)
}
// authenticateWithSchemes tries each configured auth scheme in order.
// On success it sets a session cookie and redirects; on failure it renders the login page.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
@@ -217,7 +370,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(scheme.Type().String())
}
http.Error(w, err.Error(), http.StatusBadRequest)
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return
}
@@ -233,7 +392,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
return
}
expiration := config.SessionExpiration
setSessionCookie(w, token, config.SessionExpiration)
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// setSessionCookie writes a session cookie with secure defaults.
func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) {
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
@@ -245,16 +418,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
@@ -275,13 +438,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error {
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
mw.domains[domain] = DomainConfig{
AccountID: accountID,
ServiceID: serviceID,
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
}
return nil
}
@@ -302,30 +466,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
SessionExpiration: expiration,
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
}
return nil
}
// RemoveDomain unregisters authentication for the given domain.
func (mw *Middleware) RemoveDomain(domain string) {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
delete(mw.domains, domain)
}
// validateSessionToken validates a session token, optionally checking group access via gRPC.
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
// For other auth methods (PIN, password), it validates the JWT locally.
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
// validateSessionToken validates a session token. OIDC tokens with a configured
// validator go through gRPC for group access checks; other methods validate locally.
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
// For OIDC with a session validator, call the gRPC service to check group access
if method == auth.MethodOIDC && mw.sessionValidator != nil {
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
Domain: host,
SessionToken: token,
})
if err != nil {
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
return nil, fmt.Errorf("session validation failed")
return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err)
}
if !resp.Valid {
mw.logger.WithFields(log.Fields{
@@ -342,7 +504,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
return &validationResult{UserID: resp.UserId, Valid: true}, nil
}
// For non-OIDC methods or when no validator is configured, validate JWT locally
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return nil, err

View File

@@ -1,11 +1,14 @@
package auth
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strings"
"testing"
@@ -14,10 +17,13 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/shared/management/proto"
)
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
@@ -52,11 +58,11 @@ func newPassthroughHandler() http.Handler {
}
func TestAddDomain_ValidKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
require.NoError(t, err)
mw.domainsMux.RLock()
@@ -70,10 +76,10 @@ func TestAddDomain_ValidKey(t *testing.T) {
}
func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -84,10 +90,10 @@ func TestAddDomain_EmptyKey(t *testing.T) {
}
func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key")
@@ -98,11 +104,11 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
}
func TestAddDomain_WrongKeySize(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -113,9 +119,9 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
}
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "")
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock()
@@ -125,14 +131,14 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
}
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
mw.domainsMux.RLock()
config := mw.domains["example.com"]
@@ -144,11 +150,11 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
}
func TestRemoveDomain(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
mw.RemoveDomain("example.com")
@@ -159,7 +165,7 @@ func TestRemoveDomain(t *testing.T) {
}
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
@@ -171,8 +177,8 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
}
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", ""))
mw := NewMiddleware(log.StandardLogger(), nil, nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -185,11 +191,11 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
}
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -206,11 +212,11 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
}
func TestProtect_HostWithPortIsMatched(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -227,16 +233,16 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
}
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
require.NoError(t, err)
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd)
@@ -257,11 +263,11 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
@@ -283,11 +289,11 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
@@ -309,12 +315,12 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
}
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp1 := generateTestKeyPair(t)
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
// Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
@@ -336,7 +342,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
}
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
@@ -351,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -386,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
}
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -395,7 +401,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -409,7 +415,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
}
func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
@@ -431,7 +437,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -451,7 +457,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
}
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
// Return a garbage token that won't validate.
@@ -461,7 +467,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
handler := mw.Protect(newPassthroughHandler())
@@ -473,7 +479,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
}
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
// 32 random bytes that happen to be valid base64 and correct size
// but are actually a valid ed25519 public key length-wise.
@@ -485,19 +491,19 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "")
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
}
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
// Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "")
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
require.Error(t, err)
// The original valid config should still be intact.
@@ -511,7 +517,7 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
}
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
// Scheme that always fails authentication (returns empty token)
@@ -521,9 +527,9 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
// Submit wrong PIN - should capture auth method
@@ -539,7 +545,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
}
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -548,9 +554,9 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
// Submit wrong password - should capture auth method
@@ -566,7 +572,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
}
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil)
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
@@ -575,9 +581,9 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
capturedData := &proxy.CapturedData{}
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
// No credentials submitted - should not capture auth method
@@ -658,3 +664,271 @@ func TestWasCredentialSubmitted(t *testing.T) {
})
}
}
func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
remoteAddr string
wantCode int
}{
{"unparsable address denies", "not-an-ip:1234", http.StatusForbidden},
{"empty address denies", "", http.StatusForbidden},
{"allowed address passes", "10.1.2.3:5678", http.StatusOK},
{"denied address blocked", "192.168.1.1:5678", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = tt.remoteAddr
req.Host = "example.com"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, tt.wantCode, rr.Code)
})
}
}
func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
// When CapturedData is set (by the access log middleware, which resolves
// trusted proxies), checkIPRestrictions should use that IP, not RemoteAddr.
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// RemoteAddr is a trusted proxy, but CapturedData has the real client IP.
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "10.0.0.1:5000"
req.Host = "example.com"
cd := proxy.NewCapturedData("")
cd.SetClientIP(netip.MustParseAddr("203.0.113.50"))
ctx := proxy.WithCapturedData(req.Context(), cd)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "should use CapturedData IP (203.0.113.50), not RemoteAddr (10.0.0.1)")
// Same request but CapturedData has a blocked IP.
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req2.RemoteAddr = "203.0.113.50:5000"
req2.Host = "example.com"
cd2 := proxy.NewCapturedData("")
cd2.SetClientIP(netip.MustParseAddr("10.0.0.1"))
ctx2 := proxy.WithCapturedData(req2.Context(), cd2)
req2 = req2.WithContext(ctx2)
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusForbidden, rr2.Code, "should use CapturedData IP (10.0.0.1), not RemoteAddr (203.0.113.50)")
}
func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
// Geo is nil, country restrictions are configured: must deny (fail-close).
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(nil, nil, []string{"US"}, nil))
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "1.2.3.4:5678"
req.Host = "example.com"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
}
// mockAuthenticator is a minimal mock for the authenticator gRPC interface
// used by the Header scheme.
type mockAuthenticator struct {
fn func(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error)
}
func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, _ ...grpc.CallOption) (*proto.AuthenticateResponse, error) {
return m.fn(ctx, in)
}
// newHeaderSchemeWithToken creates a Header scheme backed by a mock that
// returns a signed session token when the expected header value is provided.
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
t.Helper()
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
require.NoError(t, err)
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && ha.GetHeaderValue() == expectedValue {
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
return &proto.AuthenticateResponse{Success: false}, nil
}}
return NewHeader(mock, "svc1", "acc1", headerName)
}
func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
var backendCalled bool
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)
req.Header.Set("X-API-Key", "secret-key")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.True(t, backendCalled, "backend should be called directly for header auth (no redirect)")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "ok", rec.Body.String())
// Session cookie should be set.
var sessionCookie *http.Cookie
for _, c := range rec.Result().Cookies() {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
require.NotNil(t, sessionCookie, "session cookie should be set after successful header auth")
assert.True(t, sessionCookie.HttpOnly)
assert.True(t, sessionCookie.Secure)
assert.Equal(t, "header-user", capturedData.GetUserID())
assert.Equal(t, "header", capturedData.GetAuthMethod())
}
func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
// Also add a PIN scheme so we can verify fallthrough behavior.
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
// No X-API-Key header: should fall through to PIN login page (401).
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code, "missing header should fall through to login page")
}
func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
return &proto.AuthenticateResponse{Success: false}, nil
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("X-API-Key", "wrong-key")
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Equal(t, "header", capturedData.GetAuthMethod())
}
func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
return nil, errors.New("gRPC unavailable")
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.Header.Set("X-API-Key", "some-key")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadGateway, rec.Code)
}
func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request with header auth.
req1 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req1.Header.Set("X-API-Key", "secret-key")
req1 = req1.WithContext(proxy.WithCapturedData(req1.Context(), proxy.NewCapturedData("")))
rec1 := httptest.NewRecorder()
handler.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
// Extract session cookie.
var sessionCookie *http.Cookie
for _, c := range rec1.Result().Cookies() {
if c.Name == auth.SessionCookieName {
sessionCookie = c
break
}
}
require.NotNil(t, sessionCookie)
// Second request with only the session cookie (no header).
capturedData2 := proxy.NewCapturedData("")
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/other", nil)
req2.AddCookie(sessionCookie)
req2 = req2.WithContext(proxy.WithCapturedData(req2.Context(), capturedData2))
rec2 := httptest.NewRecorder()
handler.ServeHTTP(rec2, req2)
assert.Equal(t, http.StatusOK, rec2.Code)
assert.Equal(t, "header-user", capturedData2.GetUserID())
assert.Equal(t, "header", capturedData2.GetAuthMethod())
}

View File

@@ -0,0 +1,264 @@
package geolocation
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
mmdbInnerName = "GeoLite2-City.mmdb"
downloadTimeout = 2 * time.Minute
maxMMDBSize = 256 << 20 // 256 MB
)
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
// it downloads from pkgs.netbird.io with SHA256 verification.
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
if err := os.MkdirAll(dataDir, 0o755); err != nil {
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
}
pattern := filepath.Join(dataDir, mmdbGlob)
if files, _ := filepath.Glob(pattern); len(files) > 0 {
mmdbPath := files[len(files)-1]
logger.Debugf("using existing geolocation database: %s", mmdbPath)
return mmdbPath, nil
}
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
return downloadMMDB(logger, dataDir)
}
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
client := &http.Client{Timeout: downloadTimeout}
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
if err != nil {
return "", fmt.Errorf("get remote filename: %w", err)
}
mmdbFilename := deriveMMDBFilename(datedName)
mmdbPath := filepath.Join(dataDir, mmdbFilename)
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
if err != nil {
return "", fmt.Errorf("create temp directory: %w", err)
}
defer os.RemoveAll(tmp)
checksumFile := filepath.Join(tmp, "checksum.sha256")
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
return "", fmt.Errorf("download checksum: %w", err)
}
expectedHash, err := readChecksumFile(checksumFile)
if err != nil {
return "", fmt.Errorf("read checksum: %w", err)
}
tarFile := filepath.Join(tmp, datedName)
logger.Debugf("downloading geolocation database (%s)", datedName)
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
return "", fmt.Errorf("download database: %w", err)
}
if err := verifySHA256(tarFile, expectedHash); err != nil {
return "", fmt.Errorf("verify database checksum: %w", err)
}
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
return "", fmt.Errorf("extract database: %w", err)
}
logger.Infof("geolocation database downloaded: %s", mmdbPath)
return mmdbPath, nil
}
// deriveMMDBFilename converts a tar.gz filename to an MMDB filename.
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
func deriveMMDBFilename(tarName string) string {
base, _, _ := strings.Cut(tarName, ".")
if !strings.Contains(base, "_") {
return "GeoLite2-City.mmdb"
}
return base + ".mmdb"
}
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
resp, err := client.Head(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
}
cd := resp.Header.Get("Content-Disposition")
if cd == "" {
return "", errors.New("no Content-Disposition header")
}
_, params, err := mime.ParseMediaType(cd)
if err != nil {
return "", fmt.Errorf("parse Content-Disposition: %w", err)
}
name := filepath.Base(params["filename"])
if name == "" || name == "." {
return "", errors.New("no filename in Content-Disposition")
}
return name, nil
}
func downloadToFile(client *http.Client, url, dest string) error {
resp, err := client.Get(url) //nolint:gosec
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
f, err := os.Create(dest) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
return err
}
return nil
}
func readChecksumFile(path string) (string, error) {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
if scanner.Scan() {
parts := strings.Fields(scanner.Text())
if len(parts) > 0 {
return parts[0], nil
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", errors.New("empty checksum file")
}
func verifySHA256(path, expected string) error {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actual := fmt.Sprintf("%x", h.Sum(nil))
if actual != expected {
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
}
return nil
}
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
f, err := os.Open(tarGZPath) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gz.Close()
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
}
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
return err
}
return nil
}
}
return fmt.Errorf("%s not found in archive", mmdbInnerName)
}
// extractToFileAtomic writes r to a temporary file in the same directory as
// destPath, then renames it into place so a crash never leaves a truncated file.
func extractToFileAtomic(r io.Reader, destPath string) error {
dir := filepath.Dir(destPath)
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmp.Name()
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
if closeErr := tmp.Close(); closeErr != nil {
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
}
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("write mmdb: %w", err)
}
if err := tmp.Close(); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("close temp file: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("rename to %s: %w", destPath, err)
}
return nil
}

View File

@@ -0,0 +1,152 @@
// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases.
package geolocation
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync"
"github.com/oschwald/maxminddb-golang"
log "github.com/sirupsen/logrus"
)
const (
// EnvDisable disables geolocation lookups entirely when set to a truthy value.
EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION"
mmdbGlob = "GeoLite2-City_*.mmdb"
)
type record struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
City struct {
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"city"`
Subdivisions []struct {
ISOCode string `maxminddb:"iso_code"`
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"subdivisions"`
}
// Result holds the outcome of a geo lookup.
type Result struct {
CountryCode string
CityName string
SubdivisionCode string
SubdivisionName string
}
// Lookup provides IP geolocation lookups.
type Lookup struct {
mu sync.RWMutex
db *maxminddb.Reader
logger *log.Logger
}
// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir.
// Returns nil without error if geolocation is disabled via environment
// variable, no data directory is configured, or the download fails
// (graceful degradation: country restrictions will deny all requests).
func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) {
if isDisabledByEnv(logger) {
logger.Info("geolocation disabled via environment variable")
return nil, nil //nolint:nilnil
}
if dataDir == "" {
return nil, nil //nolint:nilnil
}
mmdbPath, err := ensureMMDB(logger, dataDir)
if err != nil {
logger.Warnf("geolocation database unavailable: %v", err)
logger.Warn("country-based access restrictions will deny all requests until a database is available")
return nil, nil //nolint:nilnil
}
db, err := maxminddb.Open(mmdbPath)
if err != nil {
return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err)
}
logger.Infof("geolocation database loaded from %s", mmdbPath)
return &Lookup{db: db, logger: logger}, nil
}
// LookupAddr returns the country ISO code and city name for the given IP.
// Returns an empty Result if the database is nil or the lookup fails.
func (l *Lookup) LookupAddr(addr netip.Addr) Result {
if l == nil {
return Result{}
}
l.mu.RLock()
defer l.mu.RUnlock()
if l.db == nil {
return Result{}
}
addr = addr.Unmap()
var rec record
if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil {
l.logger.Debugf("geolocation lookup %s: %v", addr, err)
return Result{}
}
r := Result{
CountryCode: rec.Country.ISOCode,
CityName: rec.City.Names.En,
}
if len(rec.Subdivisions) > 0 {
r.SubdivisionCode = rec.Subdivisions[0].ISOCode
r.SubdivisionName = rec.Subdivisions[0].Names.En
}
return r
}
// Available reports whether the lookup has a loaded database.
func (l *Lookup) Available() bool {
if l == nil {
return false
}
l.mu.RLock()
defer l.mu.RUnlock()
return l.db != nil
}
// Close releases the database resources.
func (l *Lookup) Close() error {
if l == nil {
return nil
}
l.mu.Lock()
defer l.mu.Unlock()
if l.db != nil {
err := l.db.Close()
l.db = nil
return err
}
return nil
}
func isDisabledByEnv(logger *log.Logger) bool {
val := os.Getenv(EnvDisable)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
logger.Warnf("parse %s=%q: %v", EnvDisable, val, err)
return false
}
return disabled
}

View File

@@ -11,8 +11,6 @@ import (
type requestContextKey string
const (
serviceIdKey requestContextKey = "serviceId"
accountIdKey requestContextKey = "accountId"
capturedDataKey requestContextKey = "capturedData"
)
@@ -47,112 +45,117 @@ func (o ResponseOrigin) String() string {
// to pass data back up the middleware chain.
type CapturedData struct {
mu sync.RWMutex
RequestID string
ServiceId types.ServiceID
AccountId types.AccountID
Origin ResponseOrigin
ClientIP netip.Addr
UserID string
AuthMethod string
requestID string
serviceID types.ServiceID
accountID types.AccountID
origin ResponseOrigin
clientIP netip.Addr
userID string
authMethod string
}
// GetRequestID safely gets the request ID
// NewCapturedData creates a CapturedData with the given request ID.
func NewCapturedData(requestID string) *CapturedData {
return &CapturedData{requestID: requestID}
}
// GetRequestID returns the request ID.
func (c *CapturedData) GetRequestID() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.RequestID
return c.requestID
}
// SetServiceId safely sets the service ID
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) {
// SetServiceID sets the service ID.
func (c *CapturedData) SetServiceID(serviceID types.ServiceID) {
c.mu.Lock()
defer c.mu.Unlock()
c.ServiceId = serviceId
c.serviceID = serviceID
}
// GetServiceId safely gets the service ID
func (c *CapturedData) GetServiceId() types.ServiceID {
// GetServiceID returns the service ID.
func (c *CapturedData) GetServiceID() types.ServiceID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ServiceId
return c.serviceID
}
// SetAccountId safely sets the account ID
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
// SetAccountID sets the account ID.
func (c *CapturedData) SetAccountID(accountID types.AccountID) {
c.mu.Lock()
defer c.mu.Unlock()
c.AccountId = accountId
c.accountID = accountID
}
// GetAccountId safely gets the account ID
func (c *CapturedData) GetAccountId() types.AccountID {
// GetAccountID returns the account ID.
func (c *CapturedData) GetAccountID() types.AccountID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AccountId
return c.accountID
}
// SetOrigin safely sets the response origin
// SetOrigin sets the response origin.
func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
c.mu.Lock()
defer c.mu.Unlock()
c.Origin = origin
c.origin = origin
}
// GetOrigin safely gets the response origin
// GetOrigin returns the response origin.
func (c *CapturedData) GetOrigin() ResponseOrigin {
c.mu.RLock()
defer c.mu.RUnlock()
return c.Origin
return c.origin
}
// SetClientIP safely sets the resolved client IP.
// SetClientIP sets the resolved client IP.
func (c *CapturedData) SetClientIP(ip netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
c.ClientIP = ip
c.clientIP = ip
}
// GetClientIP safely gets the resolved client IP.
// GetClientIP returns the resolved client IP.
func (c *CapturedData) GetClientIP() netip.Addr {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ClientIP
return c.clientIP
}
// SetUserID safely sets the authenticated user ID.
// SetUserID sets the authenticated user ID.
func (c *CapturedData) SetUserID(userID string) {
c.mu.Lock()
defer c.mu.Unlock()
c.UserID = userID
c.userID = userID
}
// GetUserID safely gets the authenticated user ID.
// GetUserID returns the authenticated user ID.
func (c *CapturedData) GetUserID() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.UserID
return c.userID
}
// SetAuthMethod safely sets the authentication method used.
// SetAuthMethod sets the authentication method used.
func (c *CapturedData) SetAuthMethod(method string) {
c.mu.Lock()
defer c.mu.Unlock()
c.AuthMethod = method
c.authMethod = method
}
// GetAuthMethod safely gets the authentication method used.
// GetAuthMethod returns the authentication method used.
func (c *CapturedData) GetAuthMethod() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AuthMethod
return c.authMethod
}
// WithCapturedData adds a CapturedData struct to the context
// WithCapturedData adds a CapturedData struct to the context.
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
return context.WithValue(ctx, capturedDataKey, data)
}
// CapturedDataFromContext retrieves the CapturedData from context
// CapturedDataFromContext retrieves the CapturedData from context.
func CapturedDataFromContext(ctx context.Context) *CapturedData {
v := ctx.Value(capturedDataKey)
data, ok := v.(*CapturedData)
@@ -161,28 +164,3 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
}
return data
}
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
return context.WithValue(ctx, serviceIdKey, serviceId)
}
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
v := ctx.Value(serviceIdKey)
serviceId, ok := v.(types.ServiceID)
if !ok {
return ""
}
return serviceId
}
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
return context.WithValue(ctx, accountIdKey, accountId)
}
func AccountIdFromContext(ctx context.Context) types.AccountID {
v := ctx.Value(accountIdKey)
accountId, ok := v.(types.AccountID)
if !ok {
return ""
}
return accountId
}

View File

@@ -66,19 +66,16 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Set the serviceId in the context for later retrieval.
ctx := withServiceId(r.Context(), result.serviceID)
// Set the accountId in the context for later retrieval (for middleware).
ctx = withAccountId(ctx, result.accountID)
// Set the accountId in the context for the roundtripper to use.
ctx := r.Context()
// Set the account ID in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, result.accountID)
// Also populate captured data if it exists (allows middleware to read after handler completes).
// Populate captured data if it exists (allows middleware to read after handler completes).
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
// pointer in the context, and mutate the struct here so outer middleware can read it.
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
capturedData.SetServiceId(result.serviceID)
capturedData.SetAccountId(result.accountID)
capturedData.SetServiceID(result.serviceID)
capturedData.SetAccountID(result.accountID)
}
pt := result.target
@@ -96,10 +93,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders),
Transport: p.transport,
FlushInterval: -1,
ErrorHandler: proxyErrorHandler,
ErrorHandler: p.proxyErrorHandler,
}
if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
@@ -113,7 +110,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address.
// The pathRewrite parameter controls how the request path is transformed.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string, stripAuthHeaders []string) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) {
switch pathRewrite {
case PathRewritePreserve:
@@ -137,6 +134,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Host = target.Host
}
for _, h := range stripAuthHeaders {
r.Out.Header.Del(h)
}
for k, v := range customHeaders {
r.Out.Header.Set(k, v)
}
@@ -305,7 +306,7 @@ func extractForwardedPort(host, resolvedProto string) string {
// proxyErrorHandler handles errors from the reverse proxy and serves
// user-friendly error pages instead of raw error responses.
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
func (p *ReverseProxy) proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
if cd := CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(OriginProxyError)
}
@@ -313,7 +314,7 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
clientIP := getClientIP(r)
title, message, code, status := classifyProxyError(err)
log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
p.logger.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
web.ServeErrorPage(w, r, code, title, message, requestID, status)

View File

@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
})
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app")
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr)
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
rewrite(pr)
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https")
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443")
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -551,7 +551,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
t.Run("preserve keeps full request path", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil, nil)
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
rewrite(pr)
@@ -561,7 +561,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
})
t.Run("preserve with root matchedPath", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil, nil)
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
rewrite(pr)
@@ -579,7 +579,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
"X-Custom-Auth": "token-abc",
"X-Env": "production",
}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -589,7 +589,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
})
t.Run("nil customHeaders is fine", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -599,7 +599,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
t.Run("custom headers override existing request headers", func(t *testing.T) {
headers := map[string]string{"X-Override": "new-value"}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("X-Override", "old-value")
@@ -609,11 +609,38 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
})
}
func TestRewriteFunc_StripsAuthorizationHeader(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
t.Run("strips incoming Authorization when no custom Authorization set", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, []string{"Authorization"})
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("Authorization", "Bearer proxy-token")
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get("Authorization"), "Authorization should be stripped")
})
t.Run("custom Authorization replaces incoming", func(t *testing.T) {
headers := map[string]string{"Authorization": "Basic YmFja2VuZDpzZWNyZXQ="}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, []string{"Authorization"})
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("Authorization", "Bearer proxy-token")
rewrite(pr)
assert.Equal(t, "Basic YmFja2VuZDpzZWNyZXQ=", pr.Out.Header.Get("Authorization"),
"backend Authorization from custom headers should be set")
})
}
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}, nil)
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
rewrite(pr)

View File

@@ -38,6 +38,11 @@ type Mapping struct {
Paths map[string]*PathTarget
PassHostHeader bool
RewriteRedirects bool
// StripAuthHeaders are header names used for header-based auth.
// These headers are stripped from requests before forwarding.
StripAuthHeaders []string
// sortedPaths caches the paths sorted by length (longest first).
sortedPaths []string
}
type targetResult struct {
@@ -47,6 +52,7 @@ type targetResult struct {
accountID types.AccountID
passHostHeader bool
rewriteRedirects bool
stripAuthHeaders []string
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
@@ -65,16 +71,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
return targetResult{}, false
}
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
for _, path := range paths {
for _, path := range m.sortedPaths {
if strings.HasPrefix(req.URL.Path, path) {
pt := m.Paths[path]
if pt == nil || pt.URL == nil {
@@ -89,6 +86,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
accountID: m.AccountID,
passHostHeader: m.PassHostHeader,
rewriteRedirects: m.RewriteRedirects,
stripAuthHeaders: m.StripAuthHeaders,
}, true
}
}
@@ -96,7 +94,18 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
return targetResult{}, false
}
// AddMapping registers a host-to-backend mapping for the reverse proxy.
func (p *ReverseProxy) AddMapping(m Mapping) {
// Sort paths longest-first to match the most specific route first.
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
m.sortedPaths = paths
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
p.mappings[m.Host] = m

View File

@@ -0,0 +1,183 @@
// Package restrict provides connection-level access control based on
// IP CIDR ranges and geolocation (country codes).
package restrict
import (
"net/netip"
"slices"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
)
// GeoResolver resolves an IP address to geographic information.
type GeoResolver interface {
LookupAddr(addr netip.Addr) geolocation.Result
Available() bool
}
// Filter evaluates IP restrictions. CIDR checks are performed first
// (cheap), followed by country lookups (more expensive) only when needed.
type Filter struct {
AllowedCIDRs []netip.Prefix
BlockedCIDRs []netip.Prefix
AllowedCountries []string
BlockedCountries []string
}
// ParseFilter builds a Filter from the raw string slices. Returns nil
// if all slices are empty.
func ParseFilter(allowedCIDRs, blockedCIDRs, allowedCountries, blockedCountries []string) *Filter {
if len(allowedCIDRs) == 0 && len(blockedCIDRs) == 0 &&
len(allowedCountries) == 0 && len(blockedCountries) == 0 {
return nil
}
f := &Filter{
AllowedCountries: normalizeCountryCodes(allowedCountries),
BlockedCountries: normalizeCountryCodes(blockedCountries),
}
for _, cidr := range allowedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
log.Warnf("skip invalid allowed CIDR %q: %v", cidr, err)
continue
}
f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked())
}
for _, cidr := range blockedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
log.Warnf("skip invalid blocked CIDR %q: %v", cidr, err)
continue
}
f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked())
}
return f
}
func normalizeCountryCodes(codes []string) []string {
if len(codes) == 0 {
return nil
}
out := make([]string, len(codes))
for i, c := range codes {
out[i] = strings.ToUpper(c)
}
return out
}
// Verdict is the result of an access check.
type Verdict int
const (
// Allow indicates the address passed all checks.
Allow Verdict = iota
// DenyCIDR indicates the address was blocked by a CIDR rule.
DenyCIDR
// DenyCountry indicates the address was blocked by a country rule.
DenyCountry
// DenyGeoUnavailable indicates that country restrictions are configured
// but the geo lookup is unavailable.
DenyGeoUnavailable
)
// String returns the deny reason string matching the HTTP auth mechanism names.
func (v Verdict) String() string {
switch v {
case Allow:
return "allow"
case DenyCIDR:
return "ip_restricted"
case DenyCountry:
return "country_restricted"
case DenyGeoUnavailable:
return "geo_unavailable"
default:
return "unknown"
}
}
// Check evaluates whether addr is permitted. CIDR rules are evaluated
// first because they are O(n) prefix comparisons. Country rules run
// only when CIDR checks pass and require a geo lookup.
func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict {
if f == nil {
return Allow
}
// Normalize v4-mapped-v6 (e.g. ::ffff:10.1.2.3) to plain v4 so that
// IPv4 CIDR rules match regardless of how the address was received.
addr = addr.Unmap()
if v := f.checkCIDR(addr); v != Allow {
return v
}
return f.checkCountry(addr, geo)
}
func (f *Filter) checkCIDR(addr netip.Addr) Verdict {
if len(f.AllowedCIDRs) > 0 {
allowed := false
for _, prefix := range f.AllowedCIDRs {
if prefix.Contains(addr) {
allowed = true
break
}
}
if !allowed {
return DenyCIDR
}
}
for _, prefix := range f.BlockedCIDRs {
if prefix.Contains(addr) {
return DenyCIDR
}
}
return Allow
}
func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict {
if len(f.AllowedCountries) == 0 && len(f.BlockedCountries) == 0 {
return Allow
}
if geo == nil || !geo.Available() {
return DenyGeoUnavailable
}
result := geo.LookupAddr(addr)
if result.CountryCode == "" {
// Unknown country: deny if an allowlist is active, allow otherwise.
// Blocklists are best-effort: unknown countries pass through since
// the default policy is allow.
if len(f.AllowedCountries) > 0 {
return DenyCountry
}
return Allow
}
if len(f.AllowedCountries) > 0 {
if !slices.Contains(f.AllowedCountries, result.CountryCode) {
return DenyCountry
}
}
if slices.Contains(f.BlockedCountries, result.CountryCode) {
return DenyCountry
}
return Allow
}
// HasRestrictions returns true if any restriction rules are configured.
func (f *Filter) HasRestrictions() bool {
if f == nil {
return false
}
return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 ||
len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0
}

View File

@@ -0,0 +1,278 @@
package restrict
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
)
type mockGeo struct {
countries map[string]string
}
func (m *mockGeo) LookupAddr(addr netip.Addr) geolocation.Result {
return geolocation.Result{CountryCode: m.countries[addr.String()]}
}
func (m *mockGeo) Available() bool { return true }
func newMockGeo(entries map[string]string) *mockGeo {
return &mockGeo{countries: entries}
}
func TestFilter_Check_NilFilter(t *testing.T) {
var f *Filter
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
}
func TestFilter_Check_AllowedCIDR(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_BlockedCIDR(t *testing.T) {
f := ParseFilter(nil, []string{"10.0.0.0/8"}, nil, nil)
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, []string{"10.1.0.0/16"}, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "not in allowlist")
}
func TestFilter_Check_AllowedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
f := ParseFilter(nil, nil, []string{"US", "DE"}, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
}
func TestFilter_Check_BlockedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "CN",
"2.2.2.2": "RU",
"3.3.3.3": "US",
})
f := ParseFilter(nil, nil, nil, []string{"CN", "RU"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "US not in blocklist")
}
func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
// Allow US and DE, but block DE explicitly.
f := ParseFilter(nil, nil, []string{"US", "DE"}, []string{"DE"})
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
}
func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
})
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist")
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active")
}
func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "CN",
})
f := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active")
}
func TestFilter_Check_CountryWithoutGeo(t *testing.T) {
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist")
}
func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) {
f := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist")
}
func TestFilter_Check_GeoUnavailable(t *testing.T) {
geo := &unavailableGeo{}
f := ParseFilter(nil, nil, []string{"US"}, nil)
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist")
f2 := ParseFilter(nil, nil, nil, []string{"CN"})
assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist")
}
func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
// CIDR-only filter should never touch geo, so nil geo is fine.
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
}
func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) {
geo := newMockGeo(map[string]string{
"10.1.2.3": "CN",
"10.2.3.4": "US",
})
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, []string{"CN"})
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), geo), "CIDR denied before country check")
}
func TestParseFilter_Empty(t *testing.T) {
f := ParseFilter(nil, nil, nil, nil)
assert.Nil(t, f)
}
func TestParseFilter_InvalidCIDR(t *testing.T) {
f := ParseFilter([]string{"invalid", "10.0.0.0/8"}, nil, nil, nil)
assert.NotNil(t, f)
assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped")
assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), f.AllowedCIDRs[0])
}
func TestFilter_HasRestrictions(t *testing.T) {
assert.False(t, (*Filter)(nil).HasRestrictions())
assert.False(t, (&Filter{}).HasRestrictions())
assert.True(t, ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil).HasRestrictions())
assert.True(t, ParseFilter(nil, nil, []string{"US"}, nil).HasRestrictions())
}
func TestFilter_Check_IPv6CIDR(t *testing.T) {
f := ParseFilter([]string{"2001:db8::/32"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 addr not in v6 allowlist")
}
func TestFilter_Check_IPv4MappedIPv6(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
// A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR.
v4mapped := netip.MustParseAddr("::ffff:10.1.2.3")
assert.True(t, v4mapped.Is4In6(), "precondition: address is v4-in-v6")
assert.Equal(t, Allow, f.Check(v4mapped, nil), "v4-mapped-v6 must match v4 CIDR after Unmap")
v4mappedOutside := netip.MustParseAddr("::ffff:192.168.1.1")
assert.Equal(t, DenyCIDR, f.Check(v4mappedOutside, nil), "v4-mapped-v6 outside v4 CIDR")
}
func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) {
f := ParseFilter([]string{"10.0.0.0/8", "2001:db8::/32"}, nil, nil, nil)
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR")
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("::ffff:10.1.2.3"), nil), "v4-mapped matches v4 CIDR")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "v4 not in either CIDR")
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("fe80::1"), nil), "v6 not in either CIDR")
}
func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) {
// 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24.
f := ParseFilter([]string{"1.1.1.1/24"}, nil, nil, nil)
assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0])
// Verify it still matches correctly.
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.100"), nil))
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("1.1.2.1"), nil))
}
func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) {
geo := newMockGeo(map[string]string{
"1.1.1.1": "US",
"2.2.2.2": "DE",
"3.3.3.3": "CN",
})
tests := []struct {
name string
allowedCountries []string
blockedCountries []string
addr string
want Verdict
}{
{
name: "lowercase allowlist matches uppercase MaxMind code",
allowedCountries: []string{"us", "de"},
addr: "1.1.1.1",
want: Allow,
},
{
name: "mixed-case allowlist matches",
allowedCountries: []string{"Us", "dE"},
addr: "2.2.2.2",
want: Allow,
},
{
name: "lowercase allowlist rejects non-matching country",
allowedCountries: []string{"us", "de"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "lowercase blocklist blocks matching country",
blockedCountries: []string{"cn"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "mixed-case blocklist blocks matching country",
blockedCountries: []string{"Cn"},
addr: "3.3.3.3",
want: DenyCountry,
},
{
name: "lowercase blocklist does not block non-matching country",
blockedCountries: []string{"cn"},
addr: "1.1.1.1",
want: Allow,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
f := ParseFilter(nil, nil, tc.allowedCountries, tc.blockedCountries)
got := f.Check(netip.MustParseAddr(tc.addr), geo)
assert.Equal(t, tc.want, got)
})
}
}
// unavailableGeo simulates a GeoResolver whose database is not loaded.
type unavailableGeo struct{}
func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} }
func (u *unavailableGeo) Available() bool { return false }

View File

@@ -7,12 +7,14 @@ import (
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
@@ -20,6 +22,10 @@ import (
// timeout is configured.
const defaultDialTimeout = 30 * time.Second
// errAccessRestricted is returned by relayTCP for access restriction
// denials so callers can skip warn-level logging (already logged at debug).
var errAccessRestricted = errors.New("rejected by access restrictions")
// SNIHost is a typed key for SNI hostname lookups.
type SNIHost string
@@ -64,6 +70,11 @@ type Route struct {
// DialTimeout overrides the default dial timeout for this route.
// Zero uses defaultDialTimeout.
DialTimeout time.Duration
// SessionIdleTimeout overrides the default idle timeout for relay connections.
// Zero uses DefaultIdleTimeout.
SessionIdleTimeout time.Duration
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
Filter *restrict.Filter
}
// l4Logger sends layer-4 access log entries to the management server.
@@ -99,6 +110,7 @@ type Router struct {
drainDone chan struct{}
observer RelayObserver
accessLog l4Logger
geo restrict.GeoResolver
// svcCtxs tracks a context per service ID. All relay goroutines for a
// service derive from its context; canceling it kills them immediately.
svcCtxs map[types.ServiceID]context.Context
@@ -144,6 +156,7 @@ func (r *Router) HTTPListener() net.Listener {
// stored and resolved by priority at lookup time (HTTP > TCP).
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
func (r *Router) AddRoute(host SNIHost, route Route) {
host = SNIHost(strings.ToLower(string(host)))
if host == "" {
return
}
@@ -166,6 +179,8 @@ func (r *Router) AddRoute(host SNIHost, route Route) {
// Active relay connections for the service are closed immediately.
// If other routes remain for the host, they are preserved.
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
host = SNIHost(strings.ToLower(string(host)))
r.mu.Lock()
defer r.mu.Unlock()
@@ -295,7 +310,7 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
return
}
host := SNIHost(sni)
host := SNIHost(strings.ToLower(sni))
route, ok := r.lookupRoute(host)
if !ok {
r.handleUnmatched(ctx, wrapped)
@@ -308,11 +323,13 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
}
if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
r.logger.WithFields(log.Fields{
"sni": host,
"service_id": route.ServiceID,
"target": route.Target,
}).Warnf("TCP relay: %v", err)
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
"sni": host,
"service_id": route.ServiceID,
"target": route.Target,
}).Warnf("TCP relay: %v", err)
}
_ = wrapped.Close()
}
}
@@ -336,10 +353,12 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
if fb != nil {
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
r.logger.WithFields(log.Fields{
"service_id": fb.ServiceID,
"target": fb.Target,
}).Warnf("TCP relay (fallback): %v", err)
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
"service_id": fb.ServiceID,
"target": fb.Target,
}).Warnf("TCP relay (fallback): %v", err)
}
_ = conn.Close()
}
return
@@ -427,10 +446,44 @@ func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
}
}
// SetGeo sets the geolocation lookup used for country-based restrictions.
func (r *Router) SetGeo(geo restrict.GeoResolver) {
r.mu.Lock()
defer r.mu.Unlock()
r.geo = geo
}
// checkRestrictions evaluates the route's access filter against the
// connection's remote address. Returns Allow if the connection is
// permitted, or a deny verdict indicating the reason.
func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict {
if route.Filter == nil {
return restrict.Allow
}
addr, err := addrFromConn(conn)
if err != nil {
r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr())
return restrict.DenyCIDR
}
r.mu.RLock()
geo := r.geo
r.mu.RUnlock()
return route.Filter.Check(addr, geo)
}
// relayTCP sets up and runs a bidirectional TCP relay.
// The caller owns conn and must close it if this method returns an error.
// On success (nil error), both conn and backend are closed by the relay.
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
r.logL4Deny(route, conn, verdict)
return errAccessRestricted
}
svcCtx, err := r.acquireRelay(ctx, route)
if err != nil {
return err
@@ -468,8 +521,13 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
})
entry.Debug("TCP relay started")
idleTimeout := route.SessionIdleTimeout
if idleTimeout <= 0 {
idleTimeout = DefaultIdleTimeout
}
start := time.Now()
s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
elapsed := time.Since(start)
if obs != nil {
@@ -537,12 +595,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
return
}
var sourceIP netip.Addr
if remote := conn.RemoteAddr(); remote != nil {
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
sourceIP = ap.Addr().Unmap()
}
}
sourceIP, _ := addrFromConn(conn)
al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID,
@@ -556,6 +609,28 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
})
}
// logL4Deny sends an access log entry for a denied connection.
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) {
r.mu.RLock()
al := r.accessLog
r.mu.RUnlock()
if al == nil {
return
}
sourceIP, _ := addrFromConn(conn)
al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID,
ServiceID: route.ServiceID,
Protocol: route.Protocol,
Host: route.Domain,
SourceIP: sourceIP,
DenyReason: verdict.String(),
})
}
// getOrCreateServiceCtxLocked returns the context for a service, creating one
// if it doesn't exist yet. The context is a child of the server context.
// Must be called with mu held.
@@ -568,3 +643,16 @@ func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types
r.svcCancels[svcID] = cancel
return ctx
}
// addrFromConn extracts a netip.Addr from a connection's remote address.
func addrFromConn(conn net.Conn) (netip.Addr, error) {
remote := conn.RemoteAddr()
if remote == nil {
return netip.Addr{}, errors.New("no remote address")
}
ap, err := netip.ParseAddrPort(remote.String())
if err != nil {
return netip.Addr{}, err
}
return ap.Addr().Unmap(), nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
@@ -1668,3 +1669,73 @@ func startEchoPlain(t *testing.T) net.Listener {
return ln
}
// fakeAddr implements net.Addr with a custom string representation.
type fakeAddr string
func (f fakeAddr) Network() string { return "tcp" }
func (f fakeAddr) String() string { return string(f) }
// fakeConn is a minimal net.Conn with a controllable RemoteAddr.
type fakeConn struct {
net.Conn
remote net.Addr
}
func (f *fakeConn) RemoteAddr() net.Addr { return f.remote }
func TestCheckRestrictions_UnparseableAddress(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "unparsable address must be denied")
}
func TestCheckRestrictions_NilRemoteAddr(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: nil}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "nil remote address must be denied")
}
func TestCheckRestrictions_AllowedAndDenied(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(allowed, route), "10.1.2.3 in allowlist")
denied := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(denied, route), "192.168.1.1 not in allowlist")
}
func TestCheckRestrictions_NilFilter(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
route := Route{Filter: nil}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "nil filter should allow everything")
}
func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
// net.IPv4() returns a 16-byte v4-in-v6 representation internally.
// The restriction check must Unmap it to match the v4 CIDR.
conn := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 5678}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "v4-in-v6 TCPAddr must match v4 CIDR")
// Explicitly v4-mapped-v6 address string.
conn6 := &fakeConn{remote: fakeAddr("[::ffff:10.1.2.3]:5678")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn6, route), "::ffff:10.1.2.3 must match v4 CIDR")
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
@@ -67,6 +68,8 @@ type Relay struct {
dialTimeout time.Duration
sessionTTL time.Duration
maxSessions int
filter *restrict.Filter
geo restrict.GeoResolver
mu sync.RWMutex
sessions map[clientAddr]*session
@@ -114,6 +117,10 @@ type RelayConfig struct {
SessionTTL time.Duration
MaxSessions int
AccessLog l4Logger
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
Filter *restrict.Filter
// Geo is the geolocation lookup used for country-based restrictions.
Geo restrict.GeoResolver
}
// New creates a UDP relay for the given listener and backend target.
@@ -146,6 +153,8 @@ func New(parentCtx context.Context, cfg RelayConfig) *Relay {
dialTimeout: dialTimeout,
sessionTTL: sessionTTL,
maxSessions: maxSessions,
filter: cfg.Filter,
geo: cfg.Geo,
sessions: make(map[clientAddr]*session),
bufPool: sync.Pool{
New: func() any {
@@ -166,9 +175,18 @@ func (r *Relay) ServiceID() types.ServiceID {
// SetObserver sets the session lifecycle observer. Must be called before Serve.
func (r *Relay) SetObserver(obs SessionObserver) {
r.mu.Lock()
defer r.mu.Unlock()
r.observer = obs
}
// getObserver returns the current session lifecycle observer.
func (r *Relay) getObserver() SessionObserver {
r.mu.RLock()
defer r.mu.RUnlock()
return r.observer
}
// Serve starts the relay loop. It blocks until the context is canceled
// or the listener is closed.
func (r *Relay) Serve() {
@@ -209,8 +227,8 @@ func (r *Relay) Serve() {
}
sess.bytesIn.Add(int64(nw))
if r.observer != nil {
r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
if obs := r.getObserver(); obs != nil {
obs.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
}
r.bufPool.Put(bufp)
}
@@ -234,6 +252,10 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
return nil, r.ctx.Err()
}
if err := r.checkAccessRestrictions(addr); err != nil {
return nil, err
}
r.mu.Lock()
if sess, ok = r.sessions[key]; ok && sess != nil {
@@ -248,16 +270,16 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
if len(r.sessions) >= r.maxSessions {
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionRejected(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionRejected(r.accountID)
}
return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
}
if !r.sessLimiter.Allow() {
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionRejected(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionRejected(r.accountID)
}
return nil, fmt.Errorf("session creation rate limited")
}
@@ -274,8 +296,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
r.mu.Lock()
delete(r.sessions, key)
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionDialError(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionDialError(r.accountID)
}
return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
}
@@ -293,8 +315,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
r.sessions[key] = sess
r.mu.Unlock()
if r.observer != nil {
r.observer.UDPSessionStarted(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionStarted(r.accountID)
}
r.sessWg.Go(func() {
@@ -305,6 +327,21 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
return sess, nil
}
func (r *Relay) checkAccessRestrictions(addr net.Addr) error {
if r.filter == nil {
return nil
}
clientIP, err := addrFromUDPAddr(addr)
if err != nil {
return fmt.Errorf("parse client address %s for restriction check: %w", addr, err)
}
if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow {
r.logDeny(clientIP, v)
return fmt.Errorf("access restricted for %s", addr)
}
return nil
}
// relayBackendToClient reads packets from the backend and writes them
// back to the client through the public-facing listener.
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
@@ -332,8 +369,8 @@ func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
}
sess.bytesOut.Add(int64(nw))
if r.observer != nil {
r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
if obs := r.getObserver(); obs != nil {
obs.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
}
}
}
@@ -402,9 +439,10 @@ func (r *Relay) cleanupIdleSessions() {
}
r.mu.Unlock()
obs := r.getObserver()
for _, sess := range expired {
if r.observer != nil {
r.observer.UDPSessionEnded(r.accountID)
if obs != nil {
obs.UDPSessionEnded(r.accountID)
}
r.logSessionEnd(sess)
}
@@ -429,8 +467,8 @@ func (r *Relay) removeSession(sess *session) {
if removed {
r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
if r.observer != nil {
r.observer.UDPSessionEnded(r.accountID)
if obs := r.getObserver(); obs != nil {
obs.UDPSessionEnded(r.accountID)
}
r.logSessionEnd(sess)
}
@@ -459,6 +497,22 @@ func (r *Relay) logSessionEnd(sess *session) {
})
}
// logDeny sends an access log entry for a denied UDP packet.
func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) {
if r.accessLog == nil {
return
}
r.accessLog.LogL4(accesslog.L4Entry{
AccountID: r.accountID,
ServiceID: r.serviceID,
Protocol: accesslog.ProtocolUDP,
Host: r.domain,
SourceIP: clientIP,
DenyReason: verdict.String(),
})
}
// Close stops the relay, waits for all session goroutines to exit,
// and cleans up remaining sessions.
func (r *Relay) Close() {
@@ -485,12 +539,22 @@ func (r *Relay) Close() {
}
r.mu.Unlock()
obs := r.getObserver()
for _, sess := range closedSessions {
if r.observer != nil {
r.observer.UDPSessionEnded(r.accountID)
if obs != nil {
obs.UDPSessionEnded(r.accountID)
}
r.logSessionEnd(sess)
}
r.sessWg.Wait()
}
// addrFromUDPAddr extracts a netip.Addr from a net.Addr.
func addrFromUDPAddr(addr net.Addr) (netip.Addr, error) {
ap, err := netip.ParseAddrPort(addr.String())
if err != nil {
return netip.Addr{}, err
}
return ap.Addr().Unmap(), nil
}

View File

@@ -490,7 +490,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
logger := log.New()
logger.SetLevel(log.WarnLevel)
authMw := auth.NewMiddleware(logger, nil)
authMw := auth.NewMiddleware(logger, nil, nil)
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
clusterAddress := "test.proxy.io"
@@ -511,6 +511,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
0,
proxytypes.AccountID(mapping.GetAccountId()),
proxytypes.ServiceID(mapping.GetId()),
nil,
)
require.NoError(t, err)

View File

@@ -43,12 +43,14 @@ import (
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/conntrack"
"github.com/netbirdio/netbird/proxy/internal/debug"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s"
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -59,7 +61,6 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// portRouter bundles a per-port Router with its listener and cancel func.
type portRouter struct {
router *nbtcp.Router
@@ -95,6 +96,9 @@ type Server struct {
// so they can be closed during graceful shutdown, since http.Server.Shutdown
// does not handle them.
hijackTracker conntrack.HijackTracker
// geo resolves IP addresses to country/city for access restrictions and access logs.
geo restrict.GeoResolver
geoRaw *geolocation.Lookup
// routerReady is closed once mainRouter is fully initialized.
// The mapping worker waits on this before processing updates.
@@ -159,10 +163,38 @@ type Server struct {
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
// ports for TCP/UDP/TLS services.
SupportsCustomPorts bool
// DefaultDialTimeout is the default timeout for establishing backend
// connections when no per-service timeout is configured. Zero means
// each transport uses its own hardcoded default (typically 30s).
DefaultDialTimeout time.Duration
// MaxDialTimeout caps the per-service backend dial timeout.
// When the API sends a timeout, it is clamped to this value.
// When the API sends no timeout, this value is used as the default.
// Zero means no cap (the proxy honors whatever management sends).
MaxDialTimeout time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files for
// country-based access restrictions. Empty disables geo lookups.
GeoDataDir string
// MaxSessionIdleTimeout caps the per-service session idle timeout.
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
return s.MaxSessionIdleTimeout
}
return d
}
// clampDialTimeout returns d capped to MaxDialTimeout when configured.
// If d is zero, MaxDialTimeout is used as the default.
func (s *Server) clampDialTimeout(d time.Duration) time.Duration {
if s.MaxDialTimeout <= 0 {
return d
}
if d <= 0 || d > s.MaxDialTimeout {
return s.MaxDialTimeout
}
return d
}
// NotifyStatus sends a status update to management about tunnel connectivity.
@@ -226,7 +258,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
@@ -236,6 +267,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
// Create health checker before the mapping worker so it can track
// management connectivity from the first stream connection.
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
tlsConfig, err := s.configureTLS(ctx)
if err != nil {
return err
@@ -244,14 +281,33 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
}
var startupOK bool
defer func() {
if startupOK {
return
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", err)
}
}
}()
// Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
// Configure Access logs to management server.
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.startDebugEndpoint()
if err := s.startHealthServer(); err != nil {
@@ -294,6 +350,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
}
startupOK = true
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
@@ -691,6 +749,16 @@ func (s *Server) shutdownServices() {
s.portRouterWg.Wait()
wg.Wait()
if s.accessLog != nil {
s.accessLog.Close()
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation: %v", err)
}
}
}
// resolveDialFunc returns a DialContextFunc that dials through the
@@ -1073,15 +1141,20 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TCP port %d: %w", port, err)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.SetFallback(nbtcp.Route{
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTCP,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTCP,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
})
s.portMu.Lock()
@@ -1108,6 +1181,8 @@ func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("empty target address for UDP service %s", svcID)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil {
return fmt.Errorf("UDP relay for service %s: %w", svcID, err)
}
@@ -1141,15 +1216,20 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TLS port %d: %w", tlsPort, err)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTLS,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTLS,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
})
if tlsPort != s.mainPort {
@@ -1181,6 +1261,32 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser
}
}
// parseRestrictions converts a proto mapping's access restrictions into
// a restrict.Filter. Returns nil if the mapping has no restrictions.
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
r := mapping.GetAccessRestrictions()
if r == nil {
return nil
}
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
}
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
// but the proxy has no geolocation database loaded. All requests to this
// service will be denied at runtime (fail-close).
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) {
if r == nil {
return
}
if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 {
return
}
if s.geo != nil && s.geo.Available() {
return
}
s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain)
}
// l4TargetAddress extracts and validates the target address from a mapping's
// first path entry. Returns empty string if no paths exist or the address is
// not a valid host:port.
@@ -1210,15 +1316,15 @@ func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool {
}
// l4DialTimeout returns the dial timeout from the first target's options,
// falling back to the server's DefaultDialTimeout.
// clamped to MaxDialTimeout.
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration {
paths := mapping.GetPath()
if len(paths) > 0 {
if d := paths[0].GetOptions().GetRequestTimeout(); d != nil {
return d.AsDuration()
return s.clampDialTimeout(d.AsDuration())
}
}
return s.DefaultDialTimeout
return s.clampDialTimeout(0)
}
// l4SessionIdleTimeout returns the configured session idle timeout from the
@@ -1254,7 +1360,9 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
dialFn, err := s.resolveDialFunc(accountID)
if err != nil {
_ = listener.Close()
if err := listener.Close(); err != nil {
s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err)
}
return fmt.Errorf("resolve dialer for UDP: %w", err)
}
@@ -1273,8 +1381,10 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
ServiceID: svcID,
DialFunc: dialFn,
DialTimeout: s.l4DialTimeout(mapping),
SessionTTL: l4SessionIdleTimeout(mapping),
SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
AccessLog: s.accessLog,
Filter: parseRestrictions(mapping),
Geo: s.geo,
})
relay.SetObserver(s.meter)
@@ -1306,9 +1416,15 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
if mapping.GetAuth().GetOidc() {
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto))
}
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
}
ipRestrictions := parseRestrictions(mapping)
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil {
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
}
m := s.protoToMapping(ctx, mapping)
@@ -1449,12 +1565,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
pt.RequestTimeout = d.AsDuration()
}
}
if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 {
pt.RequestTimeout = s.DefaultDialTimeout
}
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
paths[pathMapping.GetPath()] = pt
}
return proxy.Mapping{
m := proxy.Mapping{
ID: types.ServiceID(mapping.GetId()),
AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(),
@@ -1462,6 +1576,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
PassHostHeader: mapping.GetPassHostHeader(),
RewriteRedirects: mapping.GetRewriteRedirects(),
}
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader())
}
return m
}
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {