mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
@@ -10,7 +10,7 @@ FROM gcr.io/distroless/base:debug
|
||||
COPY netbird-proxy /go/bin/netbird-proxy
|
||||
COPY --from=builder /tmp/passwd /etc/passwd
|
||||
COPY --from=builder /tmp/group /etc/group
|
||||
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
|
||||
USER netbird:netbird
|
||||
ENV HOME=/var/lib/netbird
|
||||
|
||||
@@ -28,7 +28,7 @@ FROM gcr.io/distroless/base:debug
|
||||
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
|
||||
COPY --from=builder /tmp/passwd /etc/passwd
|
||||
COPY --from=builder /tmp/group /etc/group
|
||||
COPY --from=builder /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
|
||||
COPY --from=builder --chown=1000:1000 --chmod=755 /tmp/certs /certs
|
||||
USER netbird:netbird
|
||||
ENV HOME=/var/lib/netbird
|
||||
|
||||
@@ -13,10 +13,11 @@ import (
|
||||
|
||||
type Method string
|
||||
|
||||
var (
|
||||
const (
|
||||
MethodPassword Method = "password"
|
||||
MethodPIN Method = "pin"
|
||||
MethodOIDC Method = "oidc"
|
||||
MethodHeader Method = "header"
|
||||
)
|
||||
|
||||
func (m Method) String() string {
|
||||
|
||||
@@ -36,31 +36,33 @@ var (
|
||||
|
||||
var (
|
||||
logLevel string
|
||||
debugLogs bool
|
||||
mgmtAddr string
|
||||
addr string
|
||||
proxyDomain string
|
||||
defaultDialTimeout time.Duration
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeEABKID string
|
||||
acmeEABHMACKey string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
forwardedProto string
|
||||
trustedProxies string
|
||||
certFile string
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wildcardCertDir string
|
||||
wgPort uint16
|
||||
proxyProtocol bool
|
||||
preSharedKey string
|
||||
supportsCustomPorts bool
|
||||
debugLogs bool
|
||||
mgmtAddr string
|
||||
addr string
|
||||
proxyDomain string
|
||||
maxDialTimeout time.Duration
|
||||
maxSessionIdleTimeout time.Duration
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeEABKID string
|
||||
acmeEABHMACKey string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
forwardedProto string
|
||||
trustedProxies string
|
||||
certFile string
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wildcardCertDir string
|
||||
wgPort uint16
|
||||
proxyProtocol bool
|
||||
preSharedKey string
|
||||
supportsCustomPorts bool
|
||||
geoDataDir string
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -99,7 +101,9 @@ func init() {
|
||||
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
|
||||
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
|
||||
rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)")
|
||||
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
|
||||
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
|
||||
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -177,17 +181,15 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
ProxyProtocol: proxyProtocol,
|
||||
PreSharedKey: preSharedKey,
|
||||
SupportsCustomPorts: supportsCustomPorts,
|
||||
DefaultDialTimeout: defaultDialTimeout,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
GeoDataDir: geoDataDir,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
if err := srv.ListenAndServe(ctx, addr); err != nil {
|
||||
logger.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return srv.ListenAndServe(ctx, addr)
|
||||
}
|
||||
|
||||
func envBoolOrDefault(key string, def bool) bool {
|
||||
@@ -197,6 +199,7 @@ func envBoolOrDefault(key string, def bool) bool {
|
||||
}
|
||||
parsed, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s=%q: %v, using default %v", key, v, err, def)
|
||||
return def
|
||||
}
|
||||
return parsed
|
||||
@@ -217,6 +220,7 @@ func envUint16OrDefault(key string, def uint16) uint16 {
|
||||
}
|
||||
parsed, err := strconv.ParseUint(v, 10, 16)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s=%q: %v, using default %d", key, v, err, def)
|
||||
return def
|
||||
}
|
||||
return uint16(parsed)
|
||||
@@ -229,6 +233,7 @@ func envDurationOrDefault(key string, def time.Duration) time.Duration {
|
||||
}
|
||||
parsed, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s=%q: %v, using default %s", key, v, err, def)
|
||||
return def
|
||||
}
|
||||
return parsed
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -22,6 +23,16 @@ const (
|
||||
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
|
||||
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
|
||||
logSendTimeout = 10 * time.Second
|
||||
|
||||
// denyCooldown is the min interval between deny log entries per service+reason
|
||||
// to prevent flooding from denied connections (e.g. UDP packets from blocked IPs).
|
||||
denyCooldown = 10 * time.Second
|
||||
|
||||
// maxDenyBuckets caps tracked deny rate-limit entries to bound memory under DDoS.
|
||||
maxDenyBuckets = 10000
|
||||
|
||||
// maxLogWorkers caps concurrent gRPC send goroutines.
|
||||
maxLogWorkers = 4096
|
||||
)
|
||||
|
||||
type domainUsage struct {
|
||||
@@ -38,6 +49,18 @@ type gRPCClient interface {
|
||||
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
|
||||
}
|
||||
|
||||
// denyBucketKey identifies a rate-limited deny log stream.
|
||||
type denyBucketKey struct {
|
||||
ServiceID types.ServiceID
|
||||
Reason string
|
||||
}
|
||||
|
||||
// denyBucket tracks rate-limited deny log entries.
|
||||
type denyBucket struct {
|
||||
lastLogged time.Time
|
||||
suppressed int64
|
||||
}
|
||||
|
||||
// Logger sends access log entries to the management server via gRPC.
|
||||
type Logger struct {
|
||||
client gRPCClient
|
||||
@@ -47,7 +70,12 @@ type Logger struct {
|
||||
usageMux sync.Mutex
|
||||
domainUsage map[string]*domainUsage
|
||||
|
||||
denyMu sync.Mutex
|
||||
denyBuckets map[denyBucketKey]*denyBucket
|
||||
|
||||
logSem chan struct{}
|
||||
cleanupCancel context.CancelFunc
|
||||
dropped atomic.Int64
|
||||
}
|
||||
|
||||
// NewLogger creates a new access log Logger. The trustedProxies parameter
|
||||
@@ -64,6 +92,8 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre
|
||||
logger: logger,
|
||||
trustedProxies: trustedProxies,
|
||||
domainUsage: make(map[string]*domainUsage),
|
||||
denyBuckets: make(map[denyBucketKey]*denyBucket),
|
||||
logSem: make(chan struct{}, maxLogWorkers),
|
||||
cleanupCancel: cancel,
|
||||
}
|
||||
|
||||
@@ -83,7 +113,7 @@ func (l *Logger) Close() {
|
||||
type logEntry struct {
|
||||
ID string
|
||||
AccountID types.AccountID
|
||||
ServiceId types.ServiceID
|
||||
ServiceID types.ServiceID
|
||||
Host string
|
||||
Path string
|
||||
DurationMs int64
|
||||
@@ -91,7 +121,7 @@ type logEntry struct {
|
||||
ResponseCode int32
|
||||
SourceIP netip.Addr
|
||||
AuthMechanism string
|
||||
UserId string
|
||||
UserID string
|
||||
AuthSuccess bool
|
||||
BytesUpload int64
|
||||
BytesDownload int64
|
||||
@@ -118,6 +148,10 @@ type L4Entry struct {
|
||||
DurationMs int64
|
||||
BytesUpload int64
|
||||
BytesDownload int64
|
||||
// DenyReason, when non-empty, indicates the connection was denied.
|
||||
// Values match the HTTP auth mechanism strings: "ip_restricted",
|
||||
// "country_restricted", "geo_unavailable".
|
||||
DenyReason string
|
||||
}
|
||||
|
||||
// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP).
|
||||
@@ -126,7 +160,7 @@ func (l *Logger) LogL4(entry L4Entry) {
|
||||
le := logEntry{
|
||||
ID: xid.New().String(),
|
||||
AccountID: entry.AccountID,
|
||||
ServiceId: entry.ServiceID,
|
||||
ServiceID: entry.ServiceID,
|
||||
Protocol: entry.Protocol,
|
||||
Host: entry.Host,
|
||||
SourceIP: entry.SourceIP,
|
||||
@@ -134,10 +168,47 @@ func (l *Logger) LogL4(entry L4Entry) {
|
||||
BytesUpload: entry.BytesUpload,
|
||||
BytesDownload: entry.BytesDownload,
|
||||
}
|
||||
if entry.DenyReason != "" {
|
||||
if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) {
|
||||
return
|
||||
}
|
||||
le.AuthMechanism = entry.DenyReason
|
||||
le.AuthSuccess = false
|
||||
}
|
||||
l.log(le)
|
||||
l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload)
|
||||
}
|
||||
|
||||
// allowDenyLog rate-limits deny log entries per service+reason combination.
|
||||
func (l *Logger) allowDenyLog(serviceID types.ServiceID, reason string) bool {
|
||||
key := denyBucketKey{ServiceID: serviceID, Reason: reason}
|
||||
now := time.Now()
|
||||
|
||||
l.denyMu.Lock()
|
||||
defer l.denyMu.Unlock()
|
||||
|
||||
b, ok := l.denyBuckets[key]
|
||||
if !ok {
|
||||
if len(l.denyBuckets) >= maxDenyBuckets {
|
||||
return false
|
||||
}
|
||||
l.denyBuckets[key] = &denyBucket{lastLogged: now}
|
||||
return true
|
||||
}
|
||||
|
||||
if now.Sub(b.lastLogged) >= denyCooldown {
|
||||
if b.suppressed > 0 {
|
||||
l.logger.Debugf("access restriction: suppressed %d deny log entries for %s (%s)", b.suppressed, serviceID, reason)
|
||||
}
|
||||
b.lastLogged = now
|
||||
b.suppressed = 0
|
||||
return true
|
||||
}
|
||||
|
||||
b.suppressed++
|
||||
return false
|
||||
}
|
||||
|
||||
func (l *Logger) log(entry logEntry) {
|
||||
// Fire off the log request in a separate routine.
|
||||
// This increases the possibility of losing a log message
|
||||
@@ -147,12 +218,21 @@ func (l *Logger) log(entry logEntry) {
|
||||
// There is also a chance that log messages will arrive at
|
||||
// the server out of order; however, the timestamp should
|
||||
// allow for resolving that on the server.
|
||||
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
|
||||
now := timestamppb.Now()
|
||||
select {
|
||||
case l.logSem <- struct{}{}:
|
||||
default:
|
||||
total := l.dropped.Add(1)
|
||||
l.logger.Debugf("access log send dropped: worker limit reached (total dropped: %d)", total)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer func() { <-l.logSem }()
|
||||
logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout)
|
||||
defer cancel()
|
||||
// Only OIDC sessions have a meaningful user identity.
|
||||
if entry.AuthMechanism != auth.MethodOIDC.String() {
|
||||
entry.UserId = ""
|
||||
entry.UserID = ""
|
||||
}
|
||||
|
||||
var sourceIP string
|
||||
@@ -165,7 +245,7 @@ func (l *Logger) log(entry logEntry) {
|
||||
LogId: entry.ID,
|
||||
AccountId: string(entry.AccountID),
|
||||
Timestamp: now,
|
||||
ServiceId: string(entry.ServiceId),
|
||||
ServiceId: string(entry.ServiceID),
|
||||
Host: entry.Host,
|
||||
Path: entry.Path,
|
||||
DurationMs: entry.DurationMs,
|
||||
@@ -173,7 +253,7 @@ func (l *Logger) log(entry logEntry) {
|
||||
ResponseCode: entry.ResponseCode,
|
||||
SourceIp: sourceIP,
|
||||
AuthMechanism: entry.AuthMechanism,
|
||||
UserId: entry.UserId,
|
||||
UserId: entry.UserID,
|
||||
AuthSuccess: entry.AuthSuccess,
|
||||
BytesUpload: entry.BytesUpload,
|
||||
BytesDownload: entry.BytesDownload,
|
||||
@@ -181,7 +261,7 @@ func (l *Logger) log(entry logEntry) {
|
||||
},
|
||||
}); err != nil {
|
||||
l.logger.WithFields(log.Fields{
|
||||
"service_id": entry.ServiceId,
|
||||
"service_id": entry.ServiceID,
|
||||
"host": entry.Host,
|
||||
"path": entry.Path,
|
||||
"duration": entry.DurationMs,
|
||||
@@ -189,7 +269,7 @@ func (l *Logger) log(entry logEntry) {
|
||||
"response_code": entry.ResponseCode,
|
||||
"source_ip": sourceIP,
|
||||
"auth_mechanism": entry.AuthMechanism,
|
||||
"user_id": entry.UserId,
|
||||
"user_id": entry.UserID,
|
||||
"auth_success": entry.AuthSuccess,
|
||||
"error": err,
|
||||
}).Error("Error sending access log on gRPC connection")
|
||||
@@ -248,7 +328,7 @@ func (l *Logger) trackUsage(domain string, bytesTransferred int64) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleUsage removes usage entries for domains that have been inactive.
|
||||
// cleanupStaleUsage removes usage and deny-rate-limit entries that have been inactive.
|
||||
func (l *Logger) cleanupStaleUsage(ctx context.Context) {
|
||||
ticker := time.NewTicker(usageCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
@@ -258,20 +338,41 @@ func (l *Logger) cleanupStaleUsage(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
l.usageMux.Lock()
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
for domain, usage := range l.domainUsage {
|
||||
if now.Sub(usage.lastActivity) > usageInactiveWindow {
|
||||
delete(l.domainUsage, domain)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
l.usageMux.Unlock()
|
||||
|
||||
if removed > 0 {
|
||||
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
|
||||
}
|
||||
l.cleanupDomainUsage(now)
|
||||
l.cleanupDenyBuckets(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) cleanupDomainUsage(now time.Time) {
|
||||
l.usageMux.Lock()
|
||||
defer l.usageMux.Unlock()
|
||||
|
||||
removed := 0
|
||||
for domain, usage := range l.domainUsage {
|
||||
if now.Sub(usage.lastActivity) > usageInactiveWindow {
|
||||
delete(l.domainUsage, domain)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if removed > 0 {
|
||||
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) cleanupDenyBuckets(now time.Time) {
|
||||
l.denyMu.Lock()
|
||||
defer l.denyMu.Unlock()
|
||||
|
||||
removed := 0
|
||||
for key, bucket := range l.denyBuckets {
|
||||
if now.Sub(bucket.lastLogged) > usageInactiveWindow {
|
||||
delete(l.denyBuckets, key)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if removed > 0 {
|
||||
l.logger.Debugf("cleaned up %d stale deny rate-limit entries", removed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
// Middleware wraps an HTTP handler to log access entries and resolve client IPs.
|
||||
func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip logging for internal proxy assets (CSS, JS, etc.)
|
||||
@@ -47,8 +48,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
// Create a mutable struct to capture data from downstream handlers.
|
||||
// We pass a pointer in the context - the pointer itself flows down immutably,
|
||||
// but the struct it points to can be mutated by inner handlers.
|
||||
capturedData := &proxy.CapturedData{RequestID: requestID}
|
||||
capturedData := proxy.NewCapturedData(requestID)
|
||||
capturedData.SetClientIP(sourceIp)
|
||||
|
||||
ctx := proxy.WithCapturedData(r.Context(), capturedData)
|
||||
|
||||
start := time.Now()
|
||||
@@ -66,8 +68,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
|
||||
entry := logEntry{
|
||||
ID: requestID,
|
||||
ServiceId: capturedData.GetServiceId(),
|
||||
AccountID: capturedData.GetAccountId(),
|
||||
ServiceID: capturedData.GetServiceID(),
|
||||
AccountID: capturedData.GetAccountID(),
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
@@ -75,14 +77,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
ResponseCode: int32(sw.status),
|
||||
SourceIP: sourceIp,
|
||||
AuthMechanism: capturedData.GetAuthMethod(),
|
||||
UserId: capturedData.GetUserID(),
|
||||
UserID: capturedData.GetUserID(),
|
||||
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
|
||||
BytesUpload: bytesUpload,
|
||||
BytesDownload: bytesDownload,
|
||||
Protocol: ProtocolHTTP,
|
||||
}
|
||||
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
|
||||
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
|
||||
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID())
|
||||
|
||||
l.log(entry)
|
||||
|
||||
|
||||
69
proxy/internal/auth/header.go
Normal file
69
proxy/internal/auth/header.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// ErrHeaderAuthFailed indicates that the header was present but the
|
||||
// credential did not validate. Callers should return 401 instead of
|
||||
// falling through to other auth schemes.
|
||||
var ErrHeaderAuthFailed = errors.New("header authentication failed")
|
||||
|
||||
// Header implements header-based authentication. The proxy checks for the
|
||||
// configured header in each request and validates its value via gRPC.
|
||||
type Header struct {
|
||||
id types.ServiceID
|
||||
accountId types.AccountID
|
||||
headerName string
|
||||
client authenticator
|
||||
}
|
||||
|
||||
// NewHeader creates a Header authentication scheme for the given header name.
|
||||
func NewHeader(client authenticator, id types.ServiceID, accountId types.AccountID, headerName string) Header {
|
||||
return Header{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
headerName: headerName,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Type returns auth.MethodHeader.
|
||||
func (Header) Type() auth.Method {
|
||||
return auth.MethodHeader
|
||||
}
|
||||
|
||||
// Authenticate checks for the configured header in the request. If absent,
|
||||
// returns empty (unauthenticated). If present, validates via gRPC.
|
||||
func (h Header) Authenticate(r *http.Request) (string, string, error) {
|
||||
value := r.Header.Get(h.headerName)
|
||||
if value == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
res, err := h.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: string(h.id),
|
||||
AccountId: string(h.accountId),
|
||||
Request: &proto.AuthenticateRequest_HeaderAuth{
|
||||
HeaderAuth: &proto.HeaderAuthRequest{
|
||||
HeaderValue: value,
|
||||
HeaderName: h.headerName,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("authenticate header: %w", err)
|
||||
}
|
||||
|
||||
if res.GetSuccess() {
|
||||
return res.GetSessionToken(), "", nil
|
||||
}
|
||||
|
||||
return "", "", ErrHeaderAuthFailed
|
||||
}
|
||||
@@ -4,9 +4,12 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -16,11 +19,16 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// errValidationUnavailable indicates that session validation failed due to
|
||||
// an infrastructure error (e.g. gRPC unavailable), not an invalid token.
|
||||
var errValidationUnavailable = errors.New("session validation unavailable")
|
||||
|
||||
type authenticator interface {
|
||||
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
|
||||
}
|
||||
@@ -40,12 +48,14 @@ type Scheme interface {
|
||||
Authenticate(*http.Request) (token string, promptData string, err error)
|
||||
}
|
||||
|
||||
// DomainConfig holds the authentication and restriction settings for a protected domain.
|
||||
type DomainConfig struct {
|
||||
Schemes []Scheme
|
||||
SessionPublicKey ed25519.PublicKey
|
||||
SessionExpiration time.Duration
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
IPRestrictions *restrict.Filter
|
||||
}
|
||||
|
||||
type validationResult struct {
|
||||
@@ -54,17 +64,18 @@ type validationResult struct {
|
||||
DeniedReason string
|
||||
}
|
||||
|
||||
// Middleware applies per-domain authentication and IP restriction checks.
|
||||
type Middleware struct {
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]DomainConfig
|
||||
logger *log.Logger
|
||||
sessionValidator SessionValidator
|
||||
geo restrict.GeoResolver
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware.
|
||||
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
|
||||
// locally without group access checks.
|
||||
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
|
||||
// NewMiddleware creates a new authentication middleware. The sessionValidator is
|
||||
// optional; if nil, OIDC session tokens are validated locally without group access checks.
|
||||
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
@@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl
|
||||
domains: make(map[string]DomainConfig),
|
||||
logger: logger,
|
||||
sessionValidator: sessionValidator,
|
||||
geo: geo,
|
||||
}
|
||||
}
|
||||
|
||||
// Protect applies authentication middleware to the passed handler.
|
||||
// For each incoming request it will be checked against the middleware's
|
||||
// internal list of protected domains.
|
||||
// If the Host domain in the inbound request is not present, then it will
|
||||
// simply be passed through.
|
||||
// However, if the Host domain is present, then the specified authentication
|
||||
// schemes for that domain will be applied to the request.
|
||||
// In the event that no authentication schemes are defined for the domain,
|
||||
// then the request will also be simply passed through.
|
||||
// Protect wraps next with per-domain authentication and IP restriction checks.
|
||||
// Requests whose Host is not registered pass through unchanged.
|
||||
func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
@@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
config, exists := mw.getDomainConfig(host)
|
||||
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
|
||||
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
if !exists || len(config.Schemes) == 0 {
|
||||
if !exists {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
@@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
// Set account and service IDs in captured data for access logging.
|
||||
setCapturedIDs(r, config)
|
||||
|
||||
if !mw.checkIPRestrictions(w, r, config) {
|
||||
return
|
||||
}
|
||||
|
||||
// Domains with no authentication schemes pass through after IP checks.
|
||||
if len(config.Schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if mw.handleOAuthCallbackError(w, r) {
|
||||
return
|
||||
}
|
||||
@@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.forwardWithHeaderAuth(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
@@ -124,11 +142,65 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
|
||||
func setCapturedIDs(r *http.Request, config DomainConfig) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(config.AccountID)
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
cd.SetAccountID(config.AccountID)
|
||||
cd.SetServiceID(config.ServiceID)
|
||||
}
|
||||
}
|
||||
|
||||
// checkIPRestrictions validates the client IP against the domain's IP restrictions.
|
||||
// Uses the resolved client IP from CapturedData (which accounts for trusted proxies)
|
||||
// rather than r.RemoteAddr directly.
|
||||
func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
|
||||
if config.IPRestrictions == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
clientIP := mw.resolveClientIP(r)
|
||||
if !clientIP.IsValid() {
|
||||
mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr)
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
|
||||
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
|
||||
if verdict == restrict.Allow {
|
||||
return true
|
||||
}
|
||||
|
||||
reason := verdict.String()
|
||||
mw.blockIPRestriction(r, reason)
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr.
|
||||
func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
if ip := cd.GetClientIP(); ip.IsValid() {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if clientIPStr == "" {
|
||||
clientIPStr = r.RemoteAddr
|
||||
}
|
||||
addr, err := netip.ParseAddr(clientIPStr)
|
||||
if err != nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return addr.Unmap()
|
||||
}
|
||||
|
||||
// blockIPRestriction sets captured data fields for an IP-restriction block event.
|
||||
func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(reason)
|
||||
}
|
||||
mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr)
|
||||
}
|
||||
|
||||
// handleOAuthCallbackError checks for error query parameters from an OAuth
|
||||
// callback and renders the access denied page if present.
|
||||
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
|
||||
@@ -146,6 +218,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
} else {
|
||||
errDesc = html.EscapeString(errDesc)
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return true
|
||||
@@ -170,6 +244,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
|
||||
// the request is forwarded directly (no redirect), which is important for API clients.
|
||||
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
for _, scheme := range config.Schemes {
|
||||
hdr, ok := scheme.(Header)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
handled := mw.tryHeaderScheme(w, r, host, config, hdr, next)
|
||||
if handled {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool {
|
||||
token, _, err := hdr.Authenticate(r)
|
||||
if err != nil {
|
||||
return mw.handleHeaderAuthError(w, r, err)
|
||||
}
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
|
||||
if err != nil {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
status := http.StatusBadRequest
|
||||
msg := "invalid session token"
|
||||
if errors.Is(err, errValidationUnavailable) {
|
||||
status = http.StatusBadGateway
|
||||
msg = "authentication service unavailable"
|
||||
}
|
||||
http.Error(w, msg, status)
|
||||
return true
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
setHeaderCapturedData(r.Context(), result.UserID)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
|
||||
setSessionCookie(w, token, config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
|
||||
if errors.Is(err, ErrHeaderAuthFailed) {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
return true
|
||||
}
|
||||
|
||||
func setHeaderCapturedData(ctx context.Context, userID string) {
|
||||
cd := proxy.CapturedDataFromContext(ctx)
|
||||
if cd == nil {
|
||||
return
|
||||
}
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
cd.SetUserID(userID)
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
// On success it sets a session cookie and redirects; on failure it renders the login page.
|
||||
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
|
||||
@@ -217,7 +370,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
status := http.StatusBadRequest
|
||||
msg := "invalid session token"
|
||||
if errors.Is(err, errValidationUnavailable) {
|
||||
status = http.StatusBadGateway
|
||||
msg = "authentication service unavailable"
|
||||
}
|
||||
http.Error(w, msg, status)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,7 +392,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
setSessionCookie(w, token, config.SessionExpiration)
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// setSessionCookie writes a session cookie with secure defaults.
|
||||
func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) {
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
@@ -245,16 +418,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
|
||||
@@ -275,13 +438,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error {
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = DomainConfig{
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -302,30 +466,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
SessionExpiration: expiration,
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveDomain unregisters authentication for the given domain.
|
||||
func (mw *Middleware) RemoveDomain(domain string) {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
delete(mw.domains, domain)
|
||||
}
|
||||
|
||||
// validateSessionToken validates a session token, optionally checking group access via gRPC.
|
||||
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
|
||||
// For other auth methods (PIN, password), it validates the JWT locally.
|
||||
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
|
||||
// validateSessionToken validates a session token. OIDC tokens with a configured
|
||||
// validator go through gRPC for group access checks; other methods validate locally.
|
||||
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
|
||||
// For OIDC with a session validator, call the gRPC service to check group access
|
||||
if method == auth.MethodOIDC && mw.sessionValidator != nil {
|
||||
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
|
||||
Domain: host,
|
||||
SessionToken: token,
|
||||
})
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
|
||||
return nil, fmt.Errorf("session validation failed")
|
||||
return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err)
|
||||
}
|
||||
if !resp.Valid {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
@@ -342,7 +504,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
|
||||
return &validationResult{UserID: resp.UserId, Valid: true}, nil
|
||||
}
|
||||
|
||||
// For non-OIDC methods or when no validator is configured, validate JWT locally
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -14,10 +17,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
@@ -52,11 +58,11 @@ func newPassthroughHandler() http.Handler {
|
||||
}
|
||||
|
||||
func TestAddDomain_ValidKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -70,10 +76,10 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -84,10 +90,10 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
@@ -98,11 +104,11 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -113,9 +119,9 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -125,14 +131,14 @@ func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
@@ -144,11 +150,11 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRemoveDomain(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
@@ -159,7 +165,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
|
||||
@@ -171,8 +177,8 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", ""))
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -185,11 +191,11 @@ func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -206,11 +212,11 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -227,16 +233,16 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd)
|
||||
@@ -257,11 +263,11 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
@@ -283,11 +289,11 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
@@ -309,12 +315,12 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
@@ -336,7 +342,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
@@ -351,7 +357,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -386,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -395,7 +401,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -409,7 +415,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
@@ -431,7 +437,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -451,7 +457,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Return a garbage token that won't validate.
|
||||
@@ -461,7 +467,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
return "invalid-jwt-token", "", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -473,7 +479,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
// 32 random bytes that happen to be valid base64 and correct size
|
||||
// but are actually a valid ed25519 public key length-wise.
|
||||
@@ -485,19 +491,19 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "")
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "")
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
@@ -511,7 +517,7 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Scheme that always fails authentication (returns empty token)
|
||||
@@ -521,9 +527,9 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// Submit wrong PIN - should capture auth method
|
||||
@@ -539,7 +545,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -548,9 +554,9 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// Submit wrong password - should capture auth method
|
||||
@@ -566,7 +572,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -575,9 +581,9 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// No credentials submitted - should not capture auth method
|
||||
@@ -658,3 +664,271 @@ func TestWasCredentialSubmitted(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
wantCode int
|
||||
}{
|
||||
{"unparsable address denies", "not-an-ip:1234", http.StatusForbidden},
|
||||
{"empty address denies", "", http.StatusForbidden},
|
||||
{"allowed address passes", "10.1.2.3:5678", http.StatusOK},
|
||||
{"denied address blocked", "192.168.1.1:5678", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
req.Host = "example.com"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, tt.wantCode, rr.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
|
||||
// When CapturedData is set (by the access log middleware, which resolves
|
||||
// trusted proxies), checkIPRestrictions should use that IP, not RemoteAddr.
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// RemoteAddr is a trusted proxy, but CapturedData has the real client IP.
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:5000"
|
||||
req.Host = "example.com"
|
||||
|
||||
cd := proxy.NewCapturedData("")
|
||||
cd.SetClientIP(netip.MustParseAddr("203.0.113.50"))
|
||||
ctx := proxy.WithCapturedData(req.Context(), cd)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "should use CapturedData IP (203.0.113.50), not RemoteAddr (10.0.0.1)")
|
||||
|
||||
// Same request but CapturedData has a blocked IP.
|
||||
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req2.RemoteAddr = "203.0.113.50:5000"
|
||||
req2.Host = "example.com"
|
||||
|
||||
cd2 := proxy.NewCapturedData("")
|
||||
cd2.SetClientIP(netip.MustParseAddr("10.0.0.1"))
|
||||
ctx2 := proxy.WithCapturedData(req2.Context(), cd2)
|
||||
req2 = req2.WithContext(ctx2)
|
||||
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusForbidden, rr2.Code, "should use CapturedData IP (10.0.0.1), not RemoteAddr (203.0.113.50)")
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
// Geo is nil, country restrictions are configured: must deny (fail-close).
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(nil, nil, []string{"US"}, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:5678"
|
||||
req.Host = "example.com"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
|
||||
}
|
||||
|
||||
// mockAuthenticator is a minimal mock for the authenticator gRPC interface
|
||||
// used by the Header scheme.
|
||||
type mockAuthenticator struct {
|
||||
fn func(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error)
|
||||
}
|
||||
|
||||
func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.AuthenticateRequest, _ ...grpc.CallOption) (*proto.AuthenticateResponse, error) {
|
||||
return m.fn(ctx, in)
|
||||
}
|
||||
|
||||
// newHeaderSchemeWithToken creates a Header scheme backed by a mock that
|
||||
// returns a signed session token when the expected header value is provided.
|
||||
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
ha := req.GetHeaderAuth()
|
||||
if ha != nil && ha.GetHeaderValue() == expectedValue {
|
||||
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
|
||||
}
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
return NewHeader(mock, "svc1", "acc1", headerName)
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
var backendCalled bool
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)
|
||||
req.Header.Set("X-API-Key", "secret-key")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.True(t, backendCalled, "backend should be called directly for header auth (no redirect)")
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "ok", rec.Body.String())
|
||||
|
||||
// Session cookie should be set.
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
if c.Name == auth.SessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, sessionCookie, "session cookie should be set after successful header auth")
|
||||
assert.True(t, sessionCookie.HttpOnly)
|
||||
assert.True(t, sessionCookie.Secure)
|
||||
|
||||
assert.Equal(t, "header-user", capturedData.GetUserID())
|
||||
assert.Equal(t, "header", capturedData.GetAuthMethod())
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
// Also add a PIN scheme so we can verify fallthrough behavior.
|
||||
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// No X-API-Key header: should fall through to PIN login page (401).
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "missing header should fall through to login page")
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("X-API-Key", "wrong-key")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Equal(t, "header", capturedData.GetAuthMethod())
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, _ *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
return nil, errors.New("gRPC unavailable")
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.Header.Set("X-API-Key", "some-key")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
}
|
||||
|
||||
func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request with header auth.
|
||||
req1 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req1.Header.Set("X-API-Key", "secret-key")
|
||||
req1 = req1.WithContext(proxy.WithCapturedData(req1.Context(), proxy.NewCapturedData("")))
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
|
||||
// Extract session cookie.
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range rec1.Result().Cookies() {
|
||||
if c.Name == auth.SessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, sessionCookie)
|
||||
|
||||
// Second request with only the session cookie (no header).
|
||||
capturedData2 := proxy.NewCapturedData("")
|
||||
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/other", nil)
|
||||
req2.AddCookie(sessionCookie)
|
||||
req2 = req2.WithContext(proxy.WithCapturedData(req2.Context(), capturedData2))
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec2.Code)
|
||||
assert.Equal(t, "header-user", capturedData2.GetUserID())
|
||||
assert.Equal(t, "header", capturedData2.GetAuthMethod())
|
||||
}
|
||||
|
||||
264
proxy/internal/geolocation/download.go
Normal file
264
proxy/internal/geolocation/download.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package geolocation
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
|
||||
mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
|
||||
mmdbInnerName = "GeoLite2-City.mmdb"
|
||||
|
||||
downloadTimeout = 2 * time.Minute
|
||||
maxMMDBSize = 256 << 20 // 256 MB
|
||||
)
|
||||
|
||||
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
|
||||
// it downloads from pkgs.netbird.io with SHA256 verification.
|
||||
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
|
||||
if err := os.MkdirAll(dataDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
|
||||
}
|
||||
|
||||
pattern := filepath.Join(dataDir, mmdbGlob)
|
||||
if files, _ := filepath.Glob(pattern); len(files) > 0 {
|
||||
mmdbPath := files[len(files)-1]
|
||||
logger.Debugf("using existing geolocation database: %s", mmdbPath)
|
||||
return mmdbPath, nil
|
||||
}
|
||||
|
||||
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
|
||||
return downloadMMDB(logger, dataDir)
|
||||
}
|
||||
|
||||
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
|
||||
client := &http.Client{Timeout: downloadTimeout}
|
||||
|
||||
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get remote filename: %w", err)
|
||||
}
|
||||
|
||||
mmdbFilename := deriveMMDBFilename(datedName)
|
||||
mmdbPath := filepath.Join(dataDir, mmdbFilename)
|
||||
|
||||
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmp)
|
||||
|
||||
checksumFile := filepath.Join(tmp, "checksum.sha256")
|
||||
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
|
||||
return "", fmt.Errorf("download checksum: %w", err)
|
||||
}
|
||||
|
||||
expectedHash, err := readChecksumFile(checksumFile)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read checksum: %w", err)
|
||||
}
|
||||
|
||||
tarFile := filepath.Join(tmp, datedName)
|
||||
logger.Debugf("downloading geolocation database (%s)", datedName)
|
||||
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
|
||||
return "", fmt.Errorf("download database: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySHA256(tarFile, expectedHash); err != nil {
|
||||
return "", fmt.Errorf("verify database checksum: %w", err)
|
||||
}
|
||||
|
||||
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
|
||||
return "", fmt.Errorf("extract database: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("geolocation database downloaded: %s", mmdbPath)
|
||||
return mmdbPath, nil
|
||||
}
|
||||
|
||||
// deriveMMDBFilename converts a tar.gz filename to an MMDB filename.
|
||||
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
|
||||
func deriveMMDBFilename(tarName string) string {
|
||||
base, _, _ := strings.Cut(tarName, ".")
|
||||
if !strings.Contains(base, "_") {
|
||||
return "GeoLite2-City.mmdb"
|
||||
}
|
||||
return base + ".mmdb"
|
||||
}
|
||||
|
||||
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
|
||||
resp, err := client.Head(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
cd := resp.Header.Get("Content-Disposition")
|
||||
if cd == "" {
|
||||
return "", errors.New("no Content-Disposition header")
|
||||
}
|
||||
|
||||
_, params, err := mime.ParseMediaType(cd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse Content-Disposition: %w", err)
|
||||
}
|
||||
|
||||
name := filepath.Base(params["filename"])
|
||||
if name == "" || name == "." {
|
||||
return "", errors.New("no filename in Content-Disposition")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func downloadToFile(client *http.Client, url, dest string) error {
|
||||
resp, err := client.Get(url) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
f, err := os.Create(dest) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
|
||||
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readChecksumFile(path string) (string, error) {
|
||||
f, err := os.Open(path) //nolint:gosec
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
if scanner.Scan() {
|
||||
parts := strings.Fields(scanner.Text())
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "", errors.New("empty checksum file")
|
||||
}
|
||||
|
||||
func verifySHA256(path, expected string) error {
|
||||
f, err := os.Open(path) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
actual := fmt.Sprintf("%x", h.Sum(nil))
|
||||
if actual != expected {
|
||||
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
|
||||
f, err := os.Open(tarGZPath) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
tr := tar.NewReader(gz)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
|
||||
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
|
||||
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
|
||||
}
|
||||
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s not found in archive", mmdbInnerName)
|
||||
}
|
||||
|
||||
// extractToFileAtomic writes r to a temporary file in the same directory as
|
||||
// destPath, then renames it into place so a crash never leaves a truncated file.
|
||||
func extractToFileAtomic(r io.Reader, destPath string) error {
|
||||
dir := filepath.Dir(destPath)
|
||||
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
|
||||
if closeErr := tmp.Close(); closeErr != nil {
|
||||
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
|
||||
}
|
||||
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
||||
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
||||
}
|
||||
return fmt.Errorf("write mmdb: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
||||
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
||||
}
|
||||
return fmt.Errorf("close temp file: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
||||
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
||||
}
|
||||
return fmt.Errorf("rename to %s: %w", destPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
152
proxy/internal/geolocation/geolocation.go
Normal file
152
proxy/internal/geolocation/geolocation.go
Normal file
@@ -0,0 +1,152 @@
|
||||
// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases.
|
||||
package geolocation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisable disables geolocation lookups entirely when set to a truthy value.
|
||||
EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION"
|
||||
|
||||
mmdbGlob = "GeoLite2-City_*.mmdb"
|
||||
)
|
||||
|
||||
type record struct {
|
||||
Country struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
City struct {
|
||||
Names struct {
|
||||
En string `maxminddb:"en"`
|
||||
} `maxminddb:"names"`
|
||||
} `maxminddb:"city"`
|
||||
Subdivisions []struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
Names struct {
|
||||
En string `maxminddb:"en"`
|
||||
} `maxminddb:"names"`
|
||||
} `maxminddb:"subdivisions"`
|
||||
}
|
||||
|
||||
// Result holds the outcome of a geo lookup.
|
||||
type Result struct {
|
||||
CountryCode string
|
||||
CityName string
|
||||
SubdivisionCode string
|
||||
SubdivisionName string
|
||||
}
|
||||
|
||||
// Lookup provides IP geolocation lookups.
|
||||
type Lookup struct {
|
||||
mu sync.RWMutex
|
||||
db *maxminddb.Reader
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir.
|
||||
// Returns nil without error if geolocation is disabled via environment
|
||||
// variable, no data directory is configured, or the download fails
|
||||
// (graceful degradation: country restrictions will deny all requests).
|
||||
func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) {
|
||||
if isDisabledByEnv(logger) {
|
||||
logger.Info("geolocation disabled via environment variable")
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
if dataDir == "" {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
mmdbPath, err := ensureMMDB(logger, dataDir)
|
||||
if err != nil {
|
||||
logger.Warnf("geolocation database unavailable: %v", err)
|
||||
logger.Warn("country-based access restrictions will deny all requests until a database is available")
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
db, err := maxminddb.Open(mmdbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err)
|
||||
}
|
||||
|
||||
logger.Infof("geolocation database loaded from %s", mmdbPath)
|
||||
return &Lookup{db: db, logger: logger}, nil
|
||||
}
|
||||
|
||||
// LookupAddr returns the country ISO code and city name for the given IP.
|
||||
// Returns an empty Result if the database is nil or the lookup fails.
|
||||
func (l *Lookup) LookupAddr(addr netip.Addr) Result {
|
||||
if l == nil {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
if l.db == nil {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
addr = addr.Unmap()
|
||||
|
||||
var rec record
|
||||
if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil {
|
||||
l.logger.Debugf("geolocation lookup %s: %v", addr, err)
|
||||
return Result{}
|
||||
}
|
||||
r := Result{
|
||||
CountryCode: rec.Country.ISOCode,
|
||||
CityName: rec.City.Names.En,
|
||||
}
|
||||
if len(rec.Subdivisions) > 0 {
|
||||
r.SubdivisionCode = rec.Subdivisions[0].ISOCode
|
||||
r.SubdivisionName = rec.Subdivisions[0].Names.En
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Available reports whether the lookup has a loaded database.
|
||||
func (l *Lookup) Available() bool {
|
||||
if l == nil {
|
||||
return false
|
||||
}
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
return l.db != nil
|
||||
}
|
||||
|
||||
// Close releases the database resources.
|
||||
func (l *Lookup) Close() error {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.db != nil {
|
||||
err := l.db.Close()
|
||||
l.db = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isDisabledByEnv(logger *log.Logger) bool {
|
||||
val := os.Getenv(EnvDisable)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
logger.Warnf("parse %s=%q: %v", EnvDisable, val, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
type requestContextKey string
|
||||
|
||||
const (
|
||||
serviceIdKey requestContextKey = "serviceId"
|
||||
accountIdKey requestContextKey = "accountId"
|
||||
capturedDataKey requestContextKey = "capturedData"
|
||||
)
|
||||
|
||||
@@ -47,112 +45,117 @@ func (o ResponseOrigin) String() string {
|
||||
// to pass data back up the middleware chain.
|
||||
type CapturedData struct {
|
||||
mu sync.RWMutex
|
||||
RequestID string
|
||||
ServiceId types.ServiceID
|
||||
AccountId types.AccountID
|
||||
Origin ResponseOrigin
|
||||
ClientIP netip.Addr
|
||||
UserID string
|
||||
AuthMethod string
|
||||
requestID string
|
||||
serviceID types.ServiceID
|
||||
accountID types.AccountID
|
||||
origin ResponseOrigin
|
||||
clientIP netip.Addr
|
||||
userID string
|
||||
authMethod string
|
||||
}
|
||||
|
||||
// GetRequestID safely gets the request ID
|
||||
// NewCapturedData creates a CapturedData with the given request ID.
|
||||
func NewCapturedData(requestID string) *CapturedData {
|
||||
return &CapturedData{requestID: requestID}
|
||||
}
|
||||
|
||||
// GetRequestID returns the request ID.
|
||||
func (c *CapturedData) GetRequestID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.RequestID
|
||||
return c.requestID
|
||||
}
|
||||
|
||||
// SetServiceId safely sets the service ID
|
||||
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) {
|
||||
// SetServiceID sets the service ID.
|
||||
func (c *CapturedData) SetServiceID(serviceID types.ServiceID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ServiceId = serviceId
|
||||
c.serviceID = serviceID
|
||||
}
|
||||
|
||||
// GetServiceId safely gets the service ID
|
||||
func (c *CapturedData) GetServiceId() types.ServiceID {
|
||||
// GetServiceID returns the service ID.
|
||||
func (c *CapturedData) GetServiceID() types.ServiceID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ServiceId
|
||||
return c.serviceID
|
||||
}
|
||||
|
||||
// SetAccountId safely sets the account ID
|
||||
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
|
||||
// SetAccountID sets the account ID.
|
||||
func (c *CapturedData) SetAccountID(accountID types.AccountID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AccountId = accountId
|
||||
c.accountID = accountID
|
||||
}
|
||||
|
||||
// GetAccountId safely gets the account ID
|
||||
func (c *CapturedData) GetAccountId() types.AccountID {
|
||||
// GetAccountID returns the account ID.
|
||||
func (c *CapturedData) GetAccountID() types.AccountID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AccountId
|
||||
return c.accountID
|
||||
}
|
||||
|
||||
// SetOrigin safely sets the response origin
|
||||
// SetOrigin sets the response origin.
|
||||
func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Origin = origin
|
||||
c.origin = origin
|
||||
}
|
||||
|
||||
// GetOrigin safely gets the response origin
|
||||
// GetOrigin returns the response origin.
|
||||
func (c *CapturedData) GetOrigin() ResponseOrigin {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.Origin
|
||||
return c.origin
|
||||
}
|
||||
|
||||
// SetClientIP safely sets the resolved client IP.
|
||||
// SetClientIP sets the resolved client IP.
|
||||
func (c *CapturedData) SetClientIP(ip netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ClientIP = ip
|
||||
c.clientIP = ip
|
||||
}
|
||||
|
||||
// GetClientIP safely gets the resolved client IP.
|
||||
// GetClientIP returns the resolved client IP.
|
||||
func (c *CapturedData) GetClientIP() netip.Addr {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ClientIP
|
||||
return c.clientIP
|
||||
}
|
||||
|
||||
// SetUserID safely sets the authenticated user ID.
|
||||
// SetUserID sets the authenticated user ID.
|
||||
func (c *CapturedData) SetUserID(userID string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.UserID = userID
|
||||
c.userID = userID
|
||||
}
|
||||
|
||||
// GetUserID safely gets the authenticated user ID.
|
||||
// GetUserID returns the authenticated user ID.
|
||||
func (c *CapturedData) GetUserID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.UserID
|
||||
return c.userID
|
||||
}
|
||||
|
||||
// SetAuthMethod safely sets the authentication method used.
|
||||
// SetAuthMethod sets the authentication method used.
|
||||
func (c *CapturedData) SetAuthMethod(method string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AuthMethod = method
|
||||
c.authMethod = method
|
||||
}
|
||||
|
||||
// GetAuthMethod safely gets the authentication method used.
|
||||
// GetAuthMethod returns the authentication method used.
|
||||
func (c *CapturedData) GetAuthMethod() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AuthMethod
|
||||
return c.authMethod
|
||||
}
|
||||
|
||||
// WithCapturedData adds a CapturedData struct to the context
|
||||
// WithCapturedData adds a CapturedData struct to the context.
|
||||
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
|
||||
return context.WithValue(ctx, capturedDataKey, data)
|
||||
}
|
||||
|
||||
// CapturedDataFromContext retrieves the CapturedData from context
|
||||
// CapturedDataFromContext retrieves the CapturedData from context.
|
||||
func CapturedDataFromContext(ctx context.Context) *CapturedData {
|
||||
v := ctx.Value(capturedDataKey)
|
||||
data, ok := v.(*CapturedData)
|
||||
@@ -161,28 +164,3 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
|
||||
return context.WithValue(ctx, serviceIdKey, serviceId)
|
||||
}
|
||||
|
||||
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
|
||||
v := ctx.Value(serviceIdKey)
|
||||
serviceId, ok := v.(types.ServiceID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return serviceId
|
||||
}
|
||||
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIdKey, accountId)
|
||||
}
|
||||
|
||||
func AccountIdFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIdKey)
|
||||
accountId, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return accountId
|
||||
}
|
||||
|
||||
@@ -66,19 +66,16 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set the serviceId in the context for later retrieval.
|
||||
ctx := withServiceId(r.Context(), result.serviceID)
|
||||
// Set the accountId in the context for later retrieval (for middleware).
|
||||
ctx = withAccountId(ctx, result.accountID)
|
||||
// Set the accountId in the context for the roundtripper to use.
|
||||
ctx := r.Context()
|
||||
// Set the account ID in the context for the roundtripper to use.
|
||||
ctx = roundtrip.WithAccountID(ctx, result.accountID)
|
||||
|
||||
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
||||
// Populate captured data if it exists (allows middleware to read after handler completes).
|
||||
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
||||
// pointer in the context, and mutate the struct here so outer middleware can read it.
|
||||
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
|
||||
capturedData.SetServiceId(result.serviceID)
|
||||
capturedData.SetAccountId(result.accountID)
|
||||
capturedData.SetServiceID(result.serviceID)
|
||||
capturedData.SetAccountID(result.accountID)
|
||||
}
|
||||
|
||||
pt := result.target
|
||||
@@ -96,10 +93,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
|
||||
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders, result.stripAuthHeaders),
|
||||
Transport: p.transport,
|
||||
FlushInterval: -1,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
ErrorHandler: p.proxyErrorHandler,
|
||||
}
|
||||
if result.rewriteRedirects {
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
|
||||
@@ -113,7 +110,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// When passHostHeader is true, the original client Host header is preserved
|
||||
// instead of being rewritten to the backend's address.
|
||||
// The pathRewrite parameter controls how the request path is transformed.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string, stripAuthHeaders []string) func(r *httputil.ProxyRequest) {
|
||||
return func(r *httputil.ProxyRequest) {
|
||||
switch pathRewrite {
|
||||
case PathRewritePreserve:
|
||||
@@ -137,6 +134,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
||||
r.Out.Host = target.Host
|
||||
}
|
||||
|
||||
for _, h := range stripAuthHeaders {
|
||||
r.Out.Header.Del(h)
|
||||
}
|
||||
|
||||
for k, v := range customHeaders {
|
||||
r.Out.Header.Set(k, v)
|
||||
}
|
||||
@@ -305,7 +306,7 @@ func extractForwardedPort(host, resolvedProto string) string {
|
||||
|
||||
// proxyErrorHandler handles errors from the reverse proxy and serves
|
||||
// user-friendly error pages instead of raw error responses.
|
||||
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
func (p *ReverseProxy) proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if cd := CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(OriginProxyError)
|
||||
}
|
||||
@@ -313,7 +314,7 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
clientIP := getClientIP(r)
|
||||
title, message, code, status := classifyProxyError(err)
|
||||
|
||||
log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
|
||||
p.logger.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
|
||||
requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
|
||||
|
||||
web.ServeErrorPage(w, r, code, title, message, requestID, status)
|
||||
|
||||
@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("auto detects https from TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("auto detects http without TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
// No TLS, but forced to https
|
||||
|
||||
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("forced http proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "http"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
t.Run("strips nb_session cookie", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
t.Run("strips session_token query parameter", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
||||
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080/app")
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
||||
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
|
||||
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Port", "8443")
|
||||
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
|
||||
// Management builds: path="/heise", target="https://heise.de:443/heise"
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
|
||||
t.Run("subpath under prefix also preserved", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
// What the behavior WOULD be if target URL had no path (true stripping)
|
||||
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
|
||||
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
// Root path "/" — no stripping expected
|
||||
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.com:443/")
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -551,7 +551,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("preserve keeps full request path", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -561,7 +561,7 @@ func TestRewriteFunc_PreservePath(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("preserve with root matchedPath", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -579,7 +579,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
|
||||
"X-Custom-Auth": "token-abc",
|
||||
"X-Env": "production",
|
||||
}
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -589,7 +589,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("nil customHeaders is fine", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -599,7 +599,7 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
|
||||
|
||||
t.Run("custom headers override existing request headers", func(t *testing.T) {
|
||||
headers := map[string]string{"X-Override": "new-value"}
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.Header.Set("X-Override", "old-value")
|
||||
|
||||
@@ -609,11 +609,38 @@ func TestRewriteFunc_CustomHeaders(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_StripsAuthorizationHeader(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("strips incoming Authorization when no custom Authorization set", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil, []string{"Authorization"})
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.Header.Set("Authorization", "Bearer proxy-token")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get("Authorization"), "Authorization should be stripped")
|
||||
})
|
||||
|
||||
t.Run("custom Authorization replaces incoming", func(t *testing.T) {
|
||||
headers := map[string]string{"Authorization": "Basic YmFja2VuZDpzZWNyZXQ="}
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers, []string{"Authorization"})
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.Header.Set("Authorization", "Bearer proxy-token")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "Basic YmFja2VuZDpzZWNyZXQ=", pr.Out.Header.Get("Authorization"),
|
||||
"backend Authorization from custom headers should be set")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"}, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
@@ -38,6 +38,11 @@ type Mapping struct {
|
||||
Paths map[string]*PathTarget
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
// StripAuthHeaders are header names used for header-based auth.
|
||||
// These headers are stripped from requests before forwarding.
|
||||
StripAuthHeaders []string
|
||||
// sortedPaths caches the paths sorted by length (longest first).
|
||||
sortedPaths []string
|
||||
}
|
||||
|
||||
type targetResult struct {
|
||||
@@ -47,6 +52,7 @@ type targetResult struct {
|
||||
accountID types.AccountID
|
||||
passHostHeader bool
|
||||
rewriteRedirects bool
|
||||
stripAuthHeaders []string
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
|
||||
@@ -65,16 +71,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
|
||||
paths := make([]string, 0, len(m.Paths))
|
||||
for path := range m.Paths {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Slice(paths, func(i, j int) bool {
|
||||
return len(paths[i]) > len(paths[j])
|
||||
})
|
||||
|
||||
for _, path := range paths {
|
||||
for _, path := range m.sortedPaths {
|
||||
if strings.HasPrefix(req.URL.Path, path) {
|
||||
pt := m.Paths[path]
|
||||
if pt == nil || pt.URL == nil {
|
||||
@@ -89,6 +86,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
|
||||
accountID: m.AccountID,
|
||||
passHostHeader: m.PassHostHeader,
|
||||
rewriteRedirects: m.RewriteRedirects,
|
||||
stripAuthHeaders: m.StripAuthHeaders,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
@@ -96,7 +94,18 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
// AddMapping registers a host-to-backend mapping for the reverse proxy.
|
||||
func (p *ReverseProxy) AddMapping(m Mapping) {
|
||||
// Sort paths longest-first to match the most specific route first.
|
||||
paths := make([]string, 0, len(m.Paths))
|
||||
for path := range m.Paths {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Slice(paths, func(i, j int) bool {
|
||||
return len(paths[i]) > len(paths[j])
|
||||
})
|
||||
m.sortedPaths = paths
|
||||
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
p.mappings[m.Host] = m
|
||||
|
||||
183
proxy/internal/restrict/restrict.go
Normal file
183
proxy/internal/restrict/restrict.go
Normal file
@@ -0,0 +1,183 @@
|
||||
// Package restrict provides connection-level access control based on
|
||||
// IP CIDR ranges and geolocation (country codes).
|
||||
package restrict
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/geolocation"
|
||||
)
|
||||
|
||||
// GeoResolver resolves an IP address to geographic information.
|
||||
type GeoResolver interface {
|
||||
LookupAddr(addr netip.Addr) geolocation.Result
|
||||
Available() bool
|
||||
}
|
||||
|
||||
// Filter evaluates IP restrictions. CIDR checks are performed first
|
||||
// (cheap), followed by country lookups (more expensive) only when needed.
|
||||
type Filter struct {
|
||||
AllowedCIDRs []netip.Prefix
|
||||
BlockedCIDRs []netip.Prefix
|
||||
AllowedCountries []string
|
||||
BlockedCountries []string
|
||||
}
|
||||
|
||||
// ParseFilter builds a Filter from the raw string slices. Returns nil
|
||||
// if all slices are empty.
|
||||
func ParseFilter(allowedCIDRs, blockedCIDRs, allowedCountries, blockedCountries []string) *Filter {
|
||||
if len(allowedCIDRs) == 0 && len(blockedCIDRs) == 0 &&
|
||||
len(allowedCountries) == 0 && len(blockedCountries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
f := &Filter{
|
||||
AllowedCountries: normalizeCountryCodes(allowedCountries),
|
||||
BlockedCountries: normalizeCountryCodes(blockedCountries),
|
||||
}
|
||||
for _, cidr := range allowedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
log.Warnf("skip invalid allowed CIDR %q: %v", cidr, err)
|
||||
continue
|
||||
}
|
||||
f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked())
|
||||
}
|
||||
for _, cidr := range blockedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
log.Warnf("skip invalid blocked CIDR %q: %v", cidr, err)
|
||||
continue
|
||||
}
|
||||
f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked())
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func normalizeCountryCodes(codes []string) []string {
|
||||
if len(codes) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(codes))
|
||||
for i, c := range codes {
|
||||
out[i] = strings.ToUpper(c)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Verdict is the result of an access check.
|
||||
type Verdict int
|
||||
|
||||
const (
|
||||
// Allow indicates the address passed all checks.
|
||||
Allow Verdict = iota
|
||||
// DenyCIDR indicates the address was blocked by a CIDR rule.
|
||||
DenyCIDR
|
||||
// DenyCountry indicates the address was blocked by a country rule.
|
||||
DenyCountry
|
||||
// DenyGeoUnavailable indicates that country restrictions are configured
|
||||
// but the geo lookup is unavailable.
|
||||
DenyGeoUnavailable
|
||||
)
|
||||
|
||||
// String returns the deny reason string matching the HTTP auth mechanism names.
|
||||
func (v Verdict) String() string {
|
||||
switch v {
|
||||
case Allow:
|
||||
return "allow"
|
||||
case DenyCIDR:
|
||||
return "ip_restricted"
|
||||
case DenyCountry:
|
||||
return "country_restricted"
|
||||
case DenyGeoUnavailable:
|
||||
return "geo_unavailable"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Check evaluates whether addr is permitted. CIDR rules are evaluated
|
||||
// first because they are O(n) prefix comparisons. Country rules run
|
||||
// only when CIDR checks pass and require a geo lookup.
|
||||
func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict {
|
||||
if f == nil {
|
||||
return Allow
|
||||
}
|
||||
|
||||
// Normalize v4-mapped-v6 (e.g. ::ffff:10.1.2.3) to plain v4 so that
|
||||
// IPv4 CIDR rules match regardless of how the address was received.
|
||||
addr = addr.Unmap()
|
||||
|
||||
if v := f.checkCIDR(addr); v != Allow {
|
||||
return v
|
||||
}
|
||||
return f.checkCountry(addr, geo)
|
||||
}
|
||||
|
||||
func (f *Filter) checkCIDR(addr netip.Addr) Verdict {
|
||||
if len(f.AllowedCIDRs) > 0 {
|
||||
allowed := false
|
||||
for _, prefix := range f.AllowedCIDRs {
|
||||
if prefix.Contains(addr) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
return DenyCIDR
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range f.BlockedCIDRs {
|
||||
if prefix.Contains(addr) {
|
||||
return DenyCIDR
|
||||
}
|
||||
}
|
||||
return Allow
|
||||
}
|
||||
|
||||
func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict {
|
||||
if len(f.AllowedCountries) == 0 && len(f.BlockedCountries) == 0 {
|
||||
return Allow
|
||||
}
|
||||
|
||||
if geo == nil || !geo.Available() {
|
||||
return DenyGeoUnavailable
|
||||
}
|
||||
|
||||
result := geo.LookupAddr(addr)
|
||||
if result.CountryCode == "" {
|
||||
// Unknown country: deny if an allowlist is active, allow otherwise.
|
||||
// Blocklists are best-effort: unknown countries pass through since
|
||||
// the default policy is allow.
|
||||
if len(f.AllowedCountries) > 0 {
|
||||
return DenyCountry
|
||||
}
|
||||
return Allow
|
||||
}
|
||||
|
||||
if len(f.AllowedCountries) > 0 {
|
||||
if !slices.Contains(f.AllowedCountries, result.CountryCode) {
|
||||
return DenyCountry
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains(f.BlockedCountries, result.CountryCode) {
|
||||
return DenyCountry
|
||||
}
|
||||
|
||||
return Allow
|
||||
}
|
||||
|
||||
// HasRestrictions returns true if any restriction rules are configured.
|
||||
func (f *Filter) HasRestrictions() bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 ||
|
||||
len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0
|
||||
}
|
||||
278
proxy/internal/restrict/restrict_test.go
Normal file
278
proxy/internal/restrict/restrict_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package restrict
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/geolocation"
|
||||
)
|
||||
|
||||
type mockGeo struct {
|
||||
countries map[string]string
|
||||
}
|
||||
|
||||
func (m *mockGeo) LookupAddr(addr netip.Addr) geolocation.Result {
|
||||
return geolocation.Result{CountryCode: m.countries[addr.String()]}
|
||||
}
|
||||
|
||||
func (m *mockGeo) Available() bool { return true }
|
||||
|
||||
func newMockGeo(entries map[string]string) *mockGeo {
|
||||
return &mockGeo{countries: entries}
|
||||
}
|
||||
|
||||
func TestFilter_Check_NilFilter(t *testing.T) {
|
||||
var f *Filter
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_Check_AllowedCIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_Check_BlockedCIDR(t *testing.T) {
|
||||
f := ParseFilter(nil, []string{"10.0.0.0/8"}, nil, nil)
|
||||
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, []string{"10.1.0.0/16"}, nil, nil)
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "not in allowlist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_AllowedCountry(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "US",
|
||||
"2.2.2.2": "DE",
|
||||
"3.3.3.3": "CN",
|
||||
})
|
||||
f := ParseFilter(nil, nil, []string{"US", "DE"}, nil)
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_BlockedCountry(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "CN",
|
||||
"2.2.2.2": "RU",
|
||||
"3.3.3.3": "US",
|
||||
})
|
||||
f := ParseFilter(nil, nil, nil, []string{"CN", "RU"})
|
||||
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "US not in blocklist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "US",
|
||||
"2.2.2.2": "DE",
|
||||
"3.3.3.3": "CN",
|
||||
})
|
||||
// Allow US and DE, but block DE explicitly.
|
||||
f := ParseFilter(nil, nil, []string{"US", "DE"}, []string{"DE"})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("3.3.3.3"), geo), "CN not in allowlist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "US",
|
||||
})
|
||||
f := ParseFilter(nil, nil, []string{"US"}, nil)
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active")
|
||||
}
|
||||
|
||||
func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "CN",
|
||||
})
|
||||
f := ParseFilter(nil, nil, nil, []string{"CN"})
|
||||
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active")
|
||||
}
|
||||
|
||||
func TestFilter_Check_CountryWithoutGeo(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, []string{"US"}, nil)
|
||||
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, nil, []string{"CN"})
|
||||
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_GeoUnavailable(t *testing.T) {
|
||||
geo := &unavailableGeo{}
|
||||
|
||||
f := ParseFilter(nil, nil, []string{"US"}, nil)
|
||||
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist")
|
||||
|
||||
f2 := ParseFilter(nil, nil, nil, []string{"CN"})
|
||||
assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
|
||||
// CIDR-only filter should never touch geo, so nil geo is fine.
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"10.1.2.3": "CN",
|
||||
"10.2.3.4": "US",
|
||||
})
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, []string{"CN"})
|
||||
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), geo), "CIDR denied before country check")
|
||||
}
|
||||
|
||||
func TestParseFilter_Empty(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, nil, nil)
|
||||
assert.Nil(t, f)
|
||||
}
|
||||
|
||||
func TestParseFilter_InvalidCIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"invalid", "10.0.0.0/8"}, nil, nil, nil)
|
||||
|
||||
assert.NotNil(t, f)
|
||||
assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped")
|
||||
assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), f.AllowedCIDRs[0])
|
||||
}
|
||||
|
||||
func TestFilter_HasRestrictions(t *testing.T) {
|
||||
assert.False(t, (*Filter)(nil).HasRestrictions())
|
||||
assert.False(t, (&Filter{}).HasRestrictions())
|
||||
assert.True(t, ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil).HasRestrictions())
|
||||
assert.True(t, ParseFilter(nil, nil, []string{"US"}, nil).HasRestrictions())
|
||||
}
|
||||
|
||||
func TestFilter_Check_IPv6CIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"2001:db8::/32"}, nil, nil, nil)
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 addr not in v6 allowlist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_IPv4MappedIPv6(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
|
||||
// A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR.
|
||||
v4mapped := netip.MustParseAddr("::ffff:10.1.2.3")
|
||||
assert.True(t, v4mapped.Is4In6(), "precondition: address is v4-in-v6")
|
||||
assert.Equal(t, Allow, f.Check(v4mapped, nil), "v4-mapped-v6 must match v4 CIDR after Unmap")
|
||||
|
||||
v4mappedOutside := netip.MustParseAddr("::ffff:192.168.1.1")
|
||||
assert.Equal(t, DenyCIDR, f.Check(v4mappedOutside, nil), "v4-mapped-v6 outside v4 CIDR")
|
||||
}
|
||||
|
||||
func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8", "2001:db8::/32"}, nil, nil, nil)
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("::ffff:10.1.2.3"), nil), "v4-mapped matches v4 CIDR")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil), "v4 not in either CIDR")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("fe80::1"), nil), "v6 not in either CIDR")
|
||||
}
|
||||
|
||||
func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) {
|
||||
// 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24.
|
||||
f := ParseFilter([]string{"1.1.1.1/24"}, nil, nil, nil)
|
||||
assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0])
|
||||
|
||||
// Verify it still matches correctly.
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.100"), nil))
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("1.1.2.1"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "US",
|
||||
"2.2.2.2": "DE",
|
||||
"3.3.3.3": "CN",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedCountries []string
|
||||
blockedCountries []string
|
||||
addr string
|
||||
want Verdict
|
||||
}{
|
||||
{
|
||||
name: "lowercase allowlist matches uppercase MaxMind code",
|
||||
allowedCountries: []string{"us", "de"},
|
||||
addr: "1.1.1.1",
|
||||
want: Allow,
|
||||
},
|
||||
{
|
||||
name: "mixed-case allowlist matches",
|
||||
allowedCountries: []string{"Us", "dE"},
|
||||
addr: "2.2.2.2",
|
||||
want: Allow,
|
||||
},
|
||||
{
|
||||
name: "lowercase allowlist rejects non-matching country",
|
||||
allowedCountries: []string{"us", "de"},
|
||||
addr: "3.3.3.3",
|
||||
want: DenyCountry,
|
||||
},
|
||||
{
|
||||
name: "lowercase blocklist blocks matching country",
|
||||
blockedCountries: []string{"cn"},
|
||||
addr: "3.3.3.3",
|
||||
want: DenyCountry,
|
||||
},
|
||||
{
|
||||
name: "mixed-case blocklist blocks matching country",
|
||||
blockedCountries: []string{"Cn"},
|
||||
addr: "3.3.3.3",
|
||||
want: DenyCountry,
|
||||
},
|
||||
{
|
||||
name: "lowercase blocklist does not block non-matching country",
|
||||
blockedCountries: []string{"cn"},
|
||||
addr: "1.1.1.1",
|
||||
want: Allow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, tc.allowedCountries, tc.blockedCountries)
|
||||
got := f.Check(netip.MustParseAddr(tc.addr), geo)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// unavailableGeo simulates a GeoResolver whose database is not loaded.
|
||||
type unavailableGeo struct{}
|
||||
|
||||
func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} }
|
||||
func (u *unavailableGeo) Available() bool { return false }
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
@@ -20,6 +22,10 @@ import (
|
||||
// timeout is configured.
|
||||
const defaultDialTimeout = 30 * time.Second
|
||||
|
||||
// errAccessRestricted is returned by relayTCP for access restriction
|
||||
// denials so callers can skip warn-level logging (already logged at debug).
|
||||
var errAccessRestricted = errors.New("rejected by access restrictions")
|
||||
|
||||
// SNIHost is a typed key for SNI hostname lookups.
|
||||
type SNIHost string
|
||||
|
||||
@@ -64,6 +70,11 @@ type Route struct {
|
||||
// DialTimeout overrides the default dial timeout for this route.
|
||||
// Zero uses defaultDialTimeout.
|
||||
DialTimeout time.Duration
|
||||
// SessionIdleTimeout overrides the default idle timeout for relay connections.
|
||||
// Zero uses DefaultIdleTimeout.
|
||||
SessionIdleTimeout time.Duration
|
||||
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
|
||||
Filter *restrict.Filter
|
||||
}
|
||||
|
||||
// l4Logger sends layer-4 access log entries to the management server.
|
||||
@@ -99,6 +110,7 @@ type Router struct {
|
||||
drainDone chan struct{}
|
||||
observer RelayObserver
|
||||
accessLog l4Logger
|
||||
geo restrict.GeoResolver
|
||||
// svcCtxs tracks a context per service ID. All relay goroutines for a
|
||||
// service derive from its context; canceling it kills them immediately.
|
||||
svcCtxs map[types.ServiceID]context.Context
|
||||
@@ -144,6 +156,7 @@ func (r *Router) HTTPListener() net.Listener {
|
||||
// stored and resolved by priority at lookup time (HTTP > TCP).
|
||||
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
|
||||
func (r *Router) AddRoute(host SNIHost, route Route) {
|
||||
host = SNIHost(strings.ToLower(string(host)))
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
@@ -166,6 +179,8 @@ func (r *Router) AddRoute(host SNIHost, route Route) {
|
||||
// Active relay connections for the service are closed immediately.
|
||||
// If other routes remain for the host, they are preserved.
|
||||
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
|
||||
host = SNIHost(strings.ToLower(string(host)))
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
@@ -295,7 +310,7 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
host := SNIHost(sni)
|
||||
host := SNIHost(strings.ToLower(sni))
|
||||
route, ok := r.lookupRoute(host)
|
||||
if !ok {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
@@ -308,11 +323,13 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"sni": host,
|
||||
"service_id": route.ServiceID,
|
||||
"target": route.Target,
|
||||
}).Warnf("TCP relay: %v", err)
|
||||
if !errors.Is(err, errAccessRestricted) {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"sni": host,
|
||||
"service_id": route.ServiceID,
|
||||
"target": route.Target,
|
||||
}).Warnf("TCP relay: %v", err)
|
||||
}
|
||||
_ = wrapped.Close()
|
||||
}
|
||||
}
|
||||
@@ -336,10 +353,12 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
|
||||
|
||||
if fb != nil {
|
||||
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"service_id": fb.ServiceID,
|
||||
"target": fb.Target,
|
||||
}).Warnf("TCP relay (fallback): %v", err)
|
||||
if !errors.Is(err, errAccessRestricted) {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"service_id": fb.ServiceID,
|
||||
"target": fb.Target,
|
||||
}).Warnf("TCP relay (fallback): %v", err)
|
||||
}
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
@@ -427,10 +446,44 @@ func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetGeo sets the geolocation lookup used for country-based restrictions.
|
||||
func (r *Router) SetGeo(geo restrict.GeoResolver) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.geo = geo
|
||||
}
|
||||
|
||||
// checkRestrictions evaluates the route's access filter against the
|
||||
// connection's remote address. Returns Allow if the connection is
|
||||
// permitted, or a deny verdict indicating the reason.
|
||||
func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict {
|
||||
if route.Filter == nil {
|
||||
return restrict.Allow
|
||||
}
|
||||
|
||||
addr, err := addrFromConn(conn)
|
||||
if err != nil {
|
||||
r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr())
|
||||
return restrict.DenyCIDR
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
geo := r.geo
|
||||
r.mu.RUnlock()
|
||||
|
||||
return route.Filter.Check(addr, geo)
|
||||
}
|
||||
|
||||
// relayTCP sets up and runs a bidirectional TCP relay.
|
||||
// The caller owns conn and must close it if this method returns an error.
|
||||
// On success (nil error), both conn and backend are closed by the relay.
|
||||
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
|
||||
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
|
||||
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
|
||||
r.logL4Deny(route, conn, verdict)
|
||||
return errAccessRestricted
|
||||
}
|
||||
|
||||
svcCtx, err := r.acquireRelay(ctx, route)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -468,8 +521,13 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
|
||||
})
|
||||
entry.Debug("TCP relay started")
|
||||
|
||||
idleTimeout := route.SessionIdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = DefaultIdleTimeout
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if obs != nil {
|
||||
@@ -537,12 +595,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
|
||||
return
|
||||
}
|
||||
|
||||
var sourceIP netip.Addr
|
||||
if remote := conn.RemoteAddr(); remote != nil {
|
||||
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
|
||||
sourceIP = ap.Addr().Unmap()
|
||||
}
|
||||
}
|
||||
sourceIP, _ := addrFromConn(conn)
|
||||
|
||||
al.LogL4(accesslog.L4Entry{
|
||||
AccountID: route.AccountID,
|
||||
@@ -556,6 +609,28 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
|
||||
})
|
||||
}
|
||||
|
||||
// logL4Deny sends an access log entry for a denied connection.
|
||||
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) {
|
||||
r.mu.RLock()
|
||||
al := r.accessLog
|
||||
r.mu.RUnlock()
|
||||
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
|
||||
sourceIP, _ := addrFromConn(conn)
|
||||
|
||||
al.LogL4(accesslog.L4Entry{
|
||||
AccountID: route.AccountID,
|
||||
ServiceID: route.ServiceID,
|
||||
Protocol: route.Protocol,
|
||||
Host: route.Domain,
|
||||
SourceIP: sourceIP,
|
||||
DenyReason: verdict.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// getOrCreateServiceCtxLocked returns the context for a service, creating one
|
||||
// if it doesn't exist yet. The context is a child of the server context.
|
||||
// Must be called with mu held.
|
||||
@@ -568,3 +643,16 @@ func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types
|
||||
r.svcCancels[svcID] = cancel
|
||||
return ctx
|
||||
}
|
||||
|
||||
// addrFromConn extracts a netip.Addr from a connection's remote address.
|
||||
func addrFromConn(conn net.Conn) (netip.Addr, error) {
|
||||
remote := conn.RemoteAddr()
|
||||
if remote == nil {
|
||||
return netip.Addr{}, errors.New("no remote address")
|
||||
}
|
||||
ap, err := netip.ParseAddrPort(remote.String())
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
return ap.Addr().Unmap(), nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
@@ -1668,3 +1669,73 @@ func startEchoPlain(t *testing.T) net.Listener {
|
||||
|
||||
return ln
|
||||
}
|
||||
|
||||
// fakeAddr implements net.Addr with a custom string representation.
|
||||
type fakeAddr string
|
||||
|
||||
func (f fakeAddr) Network() string { return "tcp" }
|
||||
func (f fakeAddr) String() string { return string(f) }
|
||||
|
||||
// fakeConn is a minimal net.Conn with a controllable RemoteAddr.
|
||||
type fakeConn struct {
|
||||
net.Conn
|
||||
remote net.Addr
|
||||
}
|
||||
|
||||
func (f *fakeConn) RemoteAddr() net.Addr { return f.remote }
|
||||
|
||||
func TestCheckRestrictions_UnparseableAddress(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
route := Route{Filter: filter}
|
||||
|
||||
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
|
||||
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "unparsable address must be denied")
|
||||
}
|
||||
|
||||
func TestCheckRestrictions_NilRemoteAddr(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
route := Route{Filter: filter}
|
||||
|
||||
conn := &fakeConn{remote: nil}
|
||||
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "nil remote address must be denied")
|
||||
}
|
||||
|
||||
func TestCheckRestrictions_AllowedAndDenied(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
route := Route{Filter: filter}
|
||||
|
||||
allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}}
|
||||
assert.Equal(t, restrict.Allow, router.checkRestrictions(allowed, route), "10.1.2.3 in allowlist")
|
||||
|
||||
denied := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}}
|
||||
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(denied, route), "192.168.1.1 not in allowlist")
|
||||
}
|
||||
|
||||
func TestCheckRestrictions_NilFilter(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
route := Route{Filter: nil}
|
||||
|
||||
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
|
||||
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "nil filter should allow everything")
|
||||
}
|
||||
|
||||
func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
route := Route{Filter: filter}
|
||||
|
||||
// net.IPv4() returns a 16-byte v4-in-v6 representation internally.
|
||||
// The restriction check must Unmap it to match the v4 CIDR.
|
||||
conn := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 5678}}
|
||||
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "v4-in-v6 TCPAddr must match v4 CIDR")
|
||||
|
||||
// Explicitly v4-mapped-v6 address string.
|
||||
conn6 := &fakeConn{remote: fakeAddr("[::ffff:10.1.2.3]:5678")}
|
||||
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn6, route), "::ffff:10.1.2.3 must match v4 CIDR")
|
||||
|
||||
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
|
||||
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
@@ -67,6 +68,8 @@ type Relay struct {
|
||||
dialTimeout time.Duration
|
||||
sessionTTL time.Duration
|
||||
maxSessions int
|
||||
filter *restrict.Filter
|
||||
geo restrict.GeoResolver
|
||||
|
||||
mu sync.RWMutex
|
||||
sessions map[clientAddr]*session
|
||||
@@ -114,6 +117,10 @@ type RelayConfig struct {
|
||||
SessionTTL time.Duration
|
||||
MaxSessions int
|
||||
AccessLog l4Logger
|
||||
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
|
||||
Filter *restrict.Filter
|
||||
// Geo is the geolocation lookup used for country-based restrictions.
|
||||
Geo restrict.GeoResolver
|
||||
}
|
||||
|
||||
// New creates a UDP relay for the given listener and backend target.
|
||||
@@ -146,6 +153,8 @@ func New(parentCtx context.Context, cfg RelayConfig) *Relay {
|
||||
dialTimeout: dialTimeout,
|
||||
sessionTTL: sessionTTL,
|
||||
maxSessions: maxSessions,
|
||||
filter: cfg.Filter,
|
||||
geo: cfg.Geo,
|
||||
sessions: make(map[clientAddr]*session),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
@@ -166,9 +175,18 @@ func (r *Relay) ServiceID() types.ServiceID {
|
||||
|
||||
// SetObserver sets the session lifecycle observer. Must be called before Serve.
|
||||
func (r *Relay) SetObserver(obs SessionObserver) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.observer = obs
|
||||
}
|
||||
|
||||
// getObserver returns the current session lifecycle observer.
|
||||
func (r *Relay) getObserver() SessionObserver {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.observer
|
||||
}
|
||||
|
||||
// Serve starts the relay loop. It blocks until the context is canceled
|
||||
// or the listener is closed.
|
||||
func (r *Relay) Serve() {
|
||||
@@ -209,8 +227,8 @@ func (r *Relay) Serve() {
|
||||
}
|
||||
sess.bytesIn.Add(int64(nw))
|
||||
|
||||
if r.observer != nil {
|
||||
r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
|
||||
if obs := r.getObserver(); obs != nil {
|
||||
obs.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
|
||||
}
|
||||
r.bufPool.Put(bufp)
|
||||
}
|
||||
@@ -234,6 +252,10 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
|
||||
return nil, r.ctx.Err()
|
||||
}
|
||||
|
||||
if err := r.checkAccessRestrictions(addr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
|
||||
if sess, ok = r.sessions[key]; ok && sess != nil {
|
||||
@@ -248,16 +270,16 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
|
||||
|
||||
if len(r.sessions) >= r.maxSessions {
|
||||
r.mu.Unlock()
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionRejected(r.accountID)
|
||||
if obs := r.getObserver(); obs != nil {
|
||||
obs.UDPSessionRejected(r.accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
|
||||
}
|
||||
|
||||
if !r.sessLimiter.Allow() {
|
||||
r.mu.Unlock()
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionRejected(r.accountID)
|
||||
if obs := r.getObserver(); obs != nil {
|
||||
obs.UDPSessionRejected(r.accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("session creation rate limited")
|
||||
}
|
||||
@@ -274,8 +296,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
|
||||
r.mu.Lock()
|
||||
delete(r.sessions, key)
|
||||
r.mu.Unlock()
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionDialError(r.accountID)
|
||||
if obs := r.getObserver(); obs != nil {
|
||||
obs.UDPSessionDialError(r.accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
|
||||
}
|
||||
@@ -293,8 +315,8 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
|
||||
r.sessions[key] = sess
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionStarted(r.accountID)
|
||||
if obs := r.getObserver(); obs != nil {
|
||||
obs.UDPSessionStarted(r.accountID)
|
||||
}
|
||||
|
||||
r.sessWg.Go(func() {
|
||||
@@ -305,6 +327,21 @@ func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (r *Relay) checkAccessRestrictions(addr net.Addr) error {
|
||||
if r.filter == nil {
|
||||
return nil
|
||||
}
|
||||
clientIP, err := addrFromUDPAddr(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse client address %s for restriction check: %w", addr, err)
|
||||
}
|
||||
if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow {
|
||||
r.logDeny(clientIP, v)
|
||||
return fmt.Errorf("access restricted for %s", addr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// relayBackendToClient reads packets from the backend and writes them
|
||||
// back to the client through the public-facing listener.
|
||||
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
|
||||
@@ -332,8 +369,8 @@ func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
|
||||
}
|
||||
sess.bytesOut.Add(int64(nw))
|
||||
|
||||
if r.observer != nil {
|
||||
r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
|
||||
if obs := r.getObserver(); obs != nil {
|
||||
obs.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -402,9 +439,10 @@ func (r *Relay) cleanupIdleSessions() {
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
obs := r.getObserver()
|
||||
for _, sess := range expired {
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionEnded(r.accountID)
|
||||
if obs != nil {
|
||||
obs.UDPSessionEnded(r.accountID)
|
||||
}
|
||||
r.logSessionEnd(sess)
|
||||
}
|
||||
@@ -429,8 +467,8 @@ func (r *Relay) removeSession(sess *session) {
|
||||
if removed {
|
||||
r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
|
||||
sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionEnded(r.accountID)
|
||||
if obs := r.getObserver(); obs != nil {
|
||||
obs.UDPSessionEnded(r.accountID)
|
||||
}
|
||||
r.logSessionEnd(sess)
|
||||
}
|
||||
@@ -459,6 +497,22 @@ func (r *Relay) logSessionEnd(sess *session) {
|
||||
})
|
||||
}
|
||||
|
||||
// logDeny sends an access log entry for a denied UDP packet.
|
||||
func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) {
|
||||
if r.accessLog == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.accessLog.LogL4(accesslog.L4Entry{
|
||||
AccountID: r.accountID,
|
||||
ServiceID: r.serviceID,
|
||||
Protocol: accesslog.ProtocolUDP,
|
||||
Host: r.domain,
|
||||
SourceIP: clientIP,
|
||||
DenyReason: verdict.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Close stops the relay, waits for all session goroutines to exit,
|
||||
// and cleans up remaining sessions.
|
||||
func (r *Relay) Close() {
|
||||
@@ -485,12 +539,22 @@ func (r *Relay) Close() {
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
obs := r.getObserver()
|
||||
for _, sess := range closedSessions {
|
||||
if r.observer != nil {
|
||||
r.observer.UDPSessionEnded(r.accountID)
|
||||
if obs != nil {
|
||||
obs.UDPSessionEnded(r.accountID)
|
||||
}
|
||||
r.logSessionEnd(sess)
|
||||
}
|
||||
|
||||
r.sessWg.Wait()
|
||||
}
|
||||
|
||||
// addrFromUDPAddr extracts a netip.Addr from a net.Addr.
|
||||
func addrFromUDPAddr(addr net.Addr) (netip.Addr, error) {
|
||||
ap, err := netip.ParseAddrPort(addr.String())
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
return ap.Addr().Unmap(), nil
|
||||
}
|
||||
|
||||
@@ -490,7 +490,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.WarnLevel)
|
||||
|
||||
authMw := auth.NewMiddleware(logger, nil)
|
||||
authMw := auth.NewMiddleware(logger, nil, nil)
|
||||
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
@@ -511,6 +511,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
0,
|
||||
proxytypes.AccountID(mapping.GetAccountId()),
|
||||
proxytypes.ServiceID(mapping.GetId()),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
188
proxy/server.go
188
proxy/server.go
@@ -43,12 +43,14 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
"github.com/netbirdio/netbird/proxy/internal/geolocation"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
@@ -59,7 +61,6 @@ import (
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
|
||||
// portRouter bundles a per-port Router with its listener and cancel func.
|
||||
type portRouter struct {
|
||||
router *nbtcp.Router
|
||||
@@ -95,6 +96,9 @@ type Server struct {
|
||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||
// does not handle them.
|
||||
hijackTracker conntrack.HijackTracker
|
||||
// geo resolves IP addresses to country/city for access restrictions and access logs.
|
||||
geo restrict.GeoResolver
|
||||
geoRaw *geolocation.Lookup
|
||||
|
||||
// routerReady is closed once mainRouter is fully initialized.
|
||||
// The mapping worker waits on this before processing updates.
|
||||
@@ -159,10 +163,38 @@ type Server struct {
|
||||
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
|
||||
// ports for TCP/UDP/TLS services.
|
||||
SupportsCustomPorts bool
|
||||
// DefaultDialTimeout is the default timeout for establishing backend
|
||||
// connections when no per-service timeout is configured. Zero means
|
||||
// each transport uses its own hardcoded default (typically 30s).
|
||||
DefaultDialTimeout time.Duration
|
||||
// MaxDialTimeout caps the per-service backend dial timeout.
|
||||
// When the API sends a timeout, it is clamped to this value.
|
||||
// When the API sends no timeout, this value is used as the default.
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
MaxDialTimeout time.Duration
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files for
|
||||
// country-based access restrictions. Empty disables geo lookups.
|
||||
GeoDataDir string
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
||||
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
|
||||
if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
|
||||
return s.MaxSessionIdleTimeout
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// clampDialTimeout returns d capped to MaxDialTimeout when configured.
|
||||
// If d is zero, MaxDialTimeout is used as the default.
|
||||
func (s *Server) clampDialTimeout(d time.Duration) time.Duration {
|
||||
if s.MaxDialTimeout <= 0 {
|
||||
return d
|
||||
}
|
||||
if d <= 0 || d > s.MaxDialTimeout {
|
||||
return s.MaxDialTimeout
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity.
|
||||
@@ -226,7 +258,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
defer runCancel()
|
||||
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
@@ -236,6 +267,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
|
||||
// Create health checker before the mapping worker so it can track
|
||||
// management connectivity from the first stream connection.
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
|
||||
|
||||
tlsConfig, err := s.configureTLS(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -244,14 +281,33 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
|
||||
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize geolocation: %w", err)
|
||||
}
|
||||
s.geoRaw = geoLookup
|
||||
if geoLookup != nil {
|
||||
s.geo = geoLookup
|
||||
}
|
||||
|
||||
var startupOK bool
|
||||
defer func() {
|
||||
if startupOK {
|
||||
return
|
||||
}
|
||||
if s.geoRaw != nil {
|
||||
if err := s.geoRaw.Close(); err != nil {
|
||||
s.Logger.Debugf("close geolocation on startup failure: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.startDebugEndpoint()
|
||||
|
||||
if err := s.startHealthServer(); err != nil {
|
||||
@@ -294,6 +350,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
|
||||
startupOK = true
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
|
||||
@@ -691,6 +749,16 @@ func (s *Server) shutdownServices() {
|
||||
s.portRouterWg.Wait()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if s.accessLog != nil {
|
||||
s.accessLog.Close()
|
||||
}
|
||||
|
||||
if s.geoRaw != nil {
|
||||
if err := s.geoRaw.Close(); err != nil {
|
||||
s.Logger.Debugf("close geolocation: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveDialFunc returns a DialContextFunc that dials through the
|
||||
@@ -1073,15 +1141,20 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
return fmt.Errorf("router for TCP port %d: %w", port, err)
|
||||
}
|
||||
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
router.SetGeo(s.geo)
|
||||
router.SetFallback(nbtcp.Route{
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTCP,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTCP,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
Filter: parseRestrictions(mapping),
|
||||
})
|
||||
|
||||
s.portMu.Lock()
|
||||
@@ -1108,6 +1181,8 @@ func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
return fmt.Errorf("empty target address for UDP service %s", svcID)
|
||||
}
|
||||
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil {
|
||||
return fmt.Errorf("UDP relay for service %s: %w", svcID, err)
|
||||
}
|
||||
@@ -1141,15 +1216,20 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
return fmt.Errorf("router for TLS port %d: %w", tlsPort, err)
|
||||
}
|
||||
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
router.SetGeo(s.geo)
|
||||
router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTLS,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTLS,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
Filter: parseRestrictions(mapping),
|
||||
})
|
||||
|
||||
if tlsPort != s.mainPort {
|
||||
@@ -1181,6 +1261,32 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser
|
||||
}
|
||||
}
|
||||
|
||||
// parseRestrictions converts a proto mapping's access restrictions into
|
||||
// a restrict.Filter. Returns nil if the mapping has no restrictions.
|
||||
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
|
||||
r := mapping.GetAccessRestrictions()
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
|
||||
}
|
||||
|
||||
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
|
||||
// but the proxy has no geolocation database loaded. All requests to this
|
||||
// service will be denied at runtime (fail-close).
|
||||
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 {
|
||||
return
|
||||
}
|
||||
if s.geo != nil && s.geo.Available() {
|
||||
return
|
||||
}
|
||||
s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain)
|
||||
}
|
||||
|
||||
// l4TargetAddress extracts and validates the target address from a mapping's
|
||||
// first path entry. Returns empty string if no paths exist or the address is
|
||||
// not a valid host:port.
|
||||
@@ -1210,15 +1316,15 @@ func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool {
|
||||
}
|
||||
|
||||
// l4DialTimeout returns the dial timeout from the first target's options,
|
||||
// falling back to the server's DefaultDialTimeout.
|
||||
// clamped to MaxDialTimeout.
|
||||
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration {
|
||||
paths := mapping.GetPath()
|
||||
if len(paths) > 0 {
|
||||
if d := paths[0].GetOptions().GetRequestTimeout(); d != nil {
|
||||
return d.AsDuration()
|
||||
return s.clampDialTimeout(d.AsDuration())
|
||||
}
|
||||
}
|
||||
return s.DefaultDialTimeout
|
||||
return s.clampDialTimeout(0)
|
||||
}
|
||||
|
||||
// l4SessionIdleTimeout returns the configured session idle timeout from the
|
||||
@@ -1254,7 +1360,9 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
|
||||
|
||||
dialFn, err := s.resolveDialFunc(accountID)
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
if err := listener.Close(); err != nil {
|
||||
s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err)
|
||||
}
|
||||
return fmt.Errorf("resolve dialer for UDP: %w", err)
|
||||
}
|
||||
|
||||
@@ -1273,8 +1381,10 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
|
||||
ServiceID: svcID,
|
||||
DialFunc: dialFn,
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionTTL: l4SessionIdleTimeout(mapping),
|
||||
SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
AccessLog: s.accessLog,
|
||||
Filter: parseRestrictions(mapping),
|
||||
Geo: s.geo,
|
||||
})
|
||||
relay.SetObserver(s.meter)
|
||||
|
||||
@@ -1306,9 +1416,15 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
if mapping.GetAuth().GetOidc() {
|
||||
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto))
|
||||
}
|
||||
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
|
||||
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
|
||||
}
|
||||
|
||||
ipRestrictions := parseRestrictions(mapping)
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil {
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
|
||||
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
|
||||
}
|
||||
m := s.protoToMapping(ctx, mapping)
|
||||
@@ -1449,12 +1565,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
|
||||
pt.RequestTimeout = d.AsDuration()
|
||||
}
|
||||
}
|
||||
if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 {
|
||||
pt.RequestTimeout = s.DefaultDialTimeout
|
||||
}
|
||||
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
|
||||
paths[pathMapping.GetPath()] = pt
|
||||
}
|
||||
return proxy.Mapping{
|
||||
m := proxy.Mapping{
|
||||
ID: types.ServiceID(mapping.GetId()),
|
||||
AccountID: types.AccountID(mapping.GetAccountId()),
|
||||
Host: mapping.GetDomain(),
|
||||
@@ -1462,6 +1576,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
|
||||
PassHostHeader: mapping.GetPassHostHeader(),
|
||||
RewriteRedirects: mapping.GetRewriteRedirects(),
|
||||
}
|
||||
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
|
||||
m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader())
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
|
||||
|
||||
Reference in New Issue
Block a user