// Package proxy runs a NetBird proxy server.
// It attempts to do everything it needs to do within the context
// of a single request to the server to try to reduce the amount
// of concurrency coordination that is required. However, it does
// start additional goroutines, e.g. for handling updates from the
// management server and for running a separate HTTP server to
// handle ACME HTTP-01 challenges (if configured).
package proxy

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"path/filepath"
	"reflect"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/pires/go-proxyproto"
	prometheus2 "github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	log "github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel/exporters/prometheus"
	"go.opentelemetry.io/otel/sdk/metric"
	"golang.org/x/exp/maps"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/netbirdio/netbird/proxy/internal/accesslog"
	"github.com/netbirdio/netbird/proxy/internal/acme"
	"github.com/netbirdio/netbird/proxy/internal/auth"
	"github.com/netbirdio/netbird/proxy/internal/certwatch"
	"github.com/netbirdio/netbird/proxy/internal/conntrack"
	"github.com/netbirdio/netbird/proxy/internal/debug"
	"github.com/netbirdio/netbird/proxy/internal/geolocation"
	proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
	"github.com/netbirdio/netbird/proxy/internal/health"
	"github.com/netbirdio/netbird/proxy/internal/k8s"
	proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
	"github.com/netbirdio/netbird/proxy/internal/netutil"
	"github.com/netbirdio/netbird/proxy/internal/proxy"
	"github.com/netbirdio/netbird/proxy/internal/restrict"
	"github.com/netbirdio/netbird/proxy/internal/roundtrip"
	nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
	"github.com/netbirdio/netbird/proxy/internal/types"
	udprelay "github.com/netbirdio/netbird/proxy/internal/udp"
	"github.com/netbirdio/netbird/proxy/web"
	"github.com/netbirdio/netbird/shared/management/domain"
	"github.com/netbirdio/netbird/shared/management/proto"
	"github.com/netbirdio/netbird/util/embeddedroots"
)

// portRouter bundles a per-port Router with its listener and cancel func.
type portRouter struct {
	router   *nbtcp.Router
	listener net.Listener
	cancel   context.CancelFunc
}

// Server is the NetBird proxy server. The exported fields are configuration
// that is read when ListenAndServe runs; the unexported fields hold runtime
// state that ListenAndServe and its helpers populate.
//
// Among the unexported servers, https serves the main handler chain, http is
// the ACME HTTP-01 challenge server (only created for the "http-01" challenge
// type) and debug is the optional debug endpoint.
type Server struct {
	mgmtClient    proto.ProxyServiceClient
	proxy         *proxy.ReverseProxy
	netbird       *roundtrip.NetBird
	acme          *acme.Manager
	auth          *auth.Middleware
	http          *http.Server
	https         *http.Server
	debug         *http.Server
	healthServer  *health.Server
	healthChecker *health.Checker
	meter         *proxymetrics.Metrics
	accessLog     *accesslog.Logger
	mainRouter    *nbtcp.Router
	mainPort      uint16
	udpMu         sync.Mutex
	udpRelays     map[types.ServiceID]*udprelay.Relay
	udpRelayWg    sync.WaitGroup
	portMu        sync.RWMutex
	portRouters   map[uint16]*portRouter
	svcPorts      map[types.ServiceID][]uint16
	lastMappings  map[types.ServiceID]*proto.ProxyMapping
	portRouterWg  sync.WaitGroup

	// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
	// so they can be closed during graceful shutdown, since http.Server.Shutdown
	// does not handle them.
	hijackTracker conntrack.HijackTracker
	// geo resolves IP addresses to country/city for access restrictions and access logs.
	// geoRaw is the underlying lookup, kept so it can be closed on shutdown.
	geo    restrict.GeoResolver
	geoRaw *geolocation.Lookup

	// routerReady is closed once mainRouter is fully initialized.
	// The mapping worker waits on this before processing updates.
	routerReady chan struct{}

	// Mostly used for debugging on management.
	startTime time.Time

	ID                       string
	Logger                   *log.Logger
	Version                  string
	ProxyURL                 string
	ManagementAddress        string
	CertificateDirectory     string
	CertificateFile          string
	CertificateKeyFile       string
	GenerateACMECertificates bool
	ACMEChallengeAddress     string
	ACMEDirectory            string
	// ACMEEABKID is the External Account Binding Key ID for CAs that require EAB (e.g., ZeroSSL).
	ACMEEABKID string
	// ACMEEABHMACKey is the External Account Binding HMAC key (base64 URL-encoded) for CAs that require EAB.
	ACMEEABHMACKey string
	// ACMEChallengeType specifies the ACME challenge type: "http-01" or "tls-alpn-01".
	// Defaults to "tls-alpn-01" if not specified.
	ACMEChallengeType string
	// CertLockMethod controls how ACME certificate locks are coordinated
	// across replicas. Default: CertLockAuto (detect environment).
	CertLockMethod acme.CertLockMethod
	// WildcardCertDir is an optional directory containing wildcard certificate
	// pairs (.crt / .key). Wildcard patterns are extracted from
	// the certificates' SAN lists. Matching domains use these static certs
	// instead of ACME.
	WildcardCertDir string

	// DebugEndpointEnabled enables the debug HTTP endpoint.
	DebugEndpointEnabled bool
	// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
	DebugEndpointAddress string
	// HealthAddress is the address for the health probe endpoint.
	HealthAddress string
	// ProxyToken is the access token for authenticating with the management server.
	ProxyToken string
	// ForwardedProto overrides the X-Forwarded-Proto value sent to backends.
	// Valid values: "auto" (detect from TLS), "http", "https".
	ForwardedProto string
	// TrustedProxies is a list of IP prefixes for trusted upstream proxies.
	// When set, forwarding headers from these sources are preserved and
	// appended to instead of being stripped.
	TrustedProxies []netip.Prefix
	// WireguardPort is the port for the NetBird tunnel interface. Use 0
	// for a random OS-assigned port. A fixed port only works with
	// single-account deployments; multiple accounts will fail to bind
	// the same port.
	WireguardPort uint16
	// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
	// When enabled, the real client IP is extracted from the PROXY header
	// sent by upstream L4 proxies that support PROXY protocol.
	ProxyProtocol bool
	// PreSharedKey used for tunnel between proxy and peers (set globally not per account)
	PreSharedKey string
	// SupportsCustomPorts indicates whether the proxy can bind arbitrary
	// ports for TCP/UDP/TLS services.
	SupportsCustomPorts bool
	// RequireSubdomain indicates whether a subdomain label is required
	// in front of this proxy's cluster domain. When true, accounts cannot
	// create services on the bare cluster domain.
	RequireSubdomain bool
	// MaxDialTimeout caps the per-service backend dial timeout.
	// When the API sends a timeout, it is clamped to this value.
	// When the API sends no timeout, this value is used as the default.
	// Zero means no cap (the proxy honors whatever management sends).
	MaxDialTimeout time.Duration
	// GeoDataDir is the directory containing GeoLite2 MMDB files for
	// country-based access restrictions. Empty disables geo lookups.
	GeoDataDir string
	// MaxSessionIdleTimeout caps the per-service session idle timeout.
	// Zero means no cap (the proxy honors whatever management sends).
	// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
	MaxSessionIdleTimeout time.Duration
}

// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
// A non-positive MaxSessionIdleTimeout means no cap, so d is returned as is.
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
	if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
		return s.MaxSessionIdleTimeout
	}
	return d
}

// clampDialTimeout returns d capped to MaxDialTimeout when configured.
// If d is zero or negative, MaxDialTimeout is used as the default.
func (s *Server) clampDialTimeout(d time.Duration) time.Duration { if s.MaxDialTimeout <= 0 { return d } if d <= 0 || d > s.MaxDialTimeout { return s.MaxDialTimeout } return d } // NotifyStatus sends a status update to management about tunnel connectivity. func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error { status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED if connected { status = proto.ProxyStatus_PROXY_STATUS_ACTIVE } _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ ServiceId: string(serviceID), AccountId: string(accountID), Status: status, CertificateIssued: false, }) return err } // NotifyCertificateIssued sends a notification to management that a certificate was issued func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error { _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ ServiceId: string(serviceID), AccountId: string(accountID), Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE, CertificateIssued: true, }) return err } func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.initDefaults() s.routerReady = make(chan struct{}) s.udpRelays = make(map[types.ServiceID]*udprelay.Relay) s.portRouters = make(map[uint16]*portRouter) s.svcPorts = make(map[types.ServiceID][]uint16) s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping) exporter, err := prometheus.New() if err != nil { return fmt.Errorf("create prometheus exporter: %w", err) } provider := metric.NewMeterProvider(metric.WithReader(exporter)) pkg := reflect.TypeOf(Server{}).PkgPath() meter := provider.Meter(pkg) s.meter, err = proxymetrics.New(ctx, meter) if err != nil { return fmt.Errorf("create metrics: %w", err) } mgmtConn, err := s.dialManagement() if err != nil { return err } defer func() { if err := mgmtConn.Close(); err != nil { 
s.Logger.Debugf("management connection close: %v", err) } }() s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) runCtx, runCancel := context.WithCancel(ctx) defer runCancel() // Initialize the netbird client, this is required to build peer connections // to proxy over. s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{ MgmtAddr: s.ManagementAddress, WGPort: s.WireguardPort, PreSharedKey: s.PreSharedKey, }, s.Logger, s, s.mgmtClient) // Create health checker before the mapping worker so it can track // management connectivity from the first stream connection. s.healthChecker = health.NewChecker(s.Logger, s.netbird) go s.newManagementMappingWorker(runCtx, s.mgmtClient) tlsConfig, err := s.configureTLS(ctx) if err != nil { return err } // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger) geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir) if err != nil { return fmt.Errorf("initialize geolocation: %w", err) } s.geoRaw = geoLookup if geoLookup != nil { s.geo = geoLookup } var startupOK bool defer func() { if startupOK { return } if s.geoRaw != nil { if err := s.geoRaw.Close(); err != nil { s.Logger.Debugf("close geolocation on startup failure: %v", err) } } }() // Configure the authentication middleware with session validator for OIDC group checks. s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo) // Configure Access logs to management server. s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) s.startDebugEndpoint() if err := s.startHealthServer(); err != nil { return err } // Build the handler chain from inside out. 
handler := http.Handler(s.proxy) handler = s.auth.Protect(handler) handler = web.AssetHandler(handler) handler = s.accessLog.Middleware(handler) handler = s.meter.Middleware(handler) handler = s.hijackTracker.Middleware(handler) // Start a raw TCP listener; the SNI router peeks at ClientHello // and routes to either the HTTP handler or a TCP relay. lc := net.ListenConfig{} ln, err := lc.Listen(ctx, "tcp", addr) if err != nil { return fmt.Errorf("listen on %s: %w", addr, err) } if s.ProxyProtocol { ln = s.wrapProxyProtocol(ln) } s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid // Set up the SNI router for TCP/HTTP multiplexing on the main port. s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr()) s.mainRouter.SetObserver(s.meter) s.mainRouter.SetAccessLogger(s.accessLog) close(s.routerReady) // The HTTP server uses the chanListener fed by the SNI router. s.https = &http.Server{ Addr: addr, Handler: handler, TLSConfig: tlsConfig, ReadHeaderTimeout: httpReadHeaderTimeout, IdleTimeout: httpIdleTimeout, ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), } startupOK = true httpsErr := make(chan error, 1) go func() { s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "") }() routerErr := make(chan error, 1) go func() { s.Logger.Debugf("starting SNI router on %s", addr) routerErr <- s.mainRouter.Serve(runCtx, ln) }() select { case err := <-httpsErr: s.shutdownServices() if !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("https server: %w", err) } return nil case err := <-routerErr: s.shutdownServices() if err != nil { return fmt.Errorf("SNI router: %w", err) } return nil case <-ctx.Done(): s.gracefulShutdown() return nil } } // initDefaults sets fallback values for optional Server fields. func (s *Server) initDefaults() { s.startTime = time.Now() // If no ID is set then one can be generated. 
if s.ID == "" { s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405") } // Fallback version option in case it is not set. if s.Version == "" { s.Version = "dev" } // If no logger is specified fallback to the standard logger. if s.Logger == nil { s.Logger = log.StandardLogger() } } // startDebugEndpoint launches the debug HTTP server if enabled. func (s *Server) startDebugEndpoint() { if !s.DebugEndpointEnabled { return } debugAddr := debugEndpointAddr(s.DebugEndpointAddress) debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger) if s.acme != nil { debugHandler.SetCertStatus(s.acme) } s.debug = &http.Server{ Addr: debugAddr, Handler: debugHandler, ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug), } go func() { s.Logger.Infof("starting debug endpoint on %s", debugAddr) if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { s.Logger.Errorf("debug endpoint error: %v", err) } }() } // startHealthServer launches the health probe and metrics server. func (s *Server) startHealthServer() error { healthAddr := s.HealthAddress if healthAddr == "" { healthAddr = defaultHealthAddr } s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true})) healthListener, err := net.Listen("tcp", healthAddr) if err != nil { return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err) } go func() { if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) { s.Logger.Errorf("health probe server: %v", err) } }() return nil } // wrapProxyProtocol wraps a listener with PROXY protocol support. // When TrustedProxies is configured, only those sources may send PROXY headers; // connections from untrusted sources have any PROXY header ignored. 
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener { ppListener := &proxyproto.Listener{ Listener: ln, ReadHeaderTimeout: proxyProtoHeaderTimeout, } if len(s.TrustedProxies) > 0 { ppListener.ConnPolicy = s.proxyProtocolPolicy } else { s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers") } s.Logger.Info("PROXY protocol enabled on listener") return ppListener } // proxyProtocolPolicy returns whether to require, skip, or reject the PROXY // header based on whether the connection source is in TrustedProxies. func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { // No logging on reject to prevent abuse tcpAddr, ok := opts.Upstream.(*net.TCPAddr) if !ok { return proxyproto.REJECT, nil } addr, ok := netip.AddrFromSlice(tcpAddr.IP) if !ok { return proxyproto.REJECT, nil } addr = addr.Unmap() // called per accept for _, prefix := range s.TrustedProxies { if prefix.Contains(addr) { return proxyproto.REQUIRE, nil } } return proxyproto.IGNORE, nil } const ( defaultHealthAddr = "localhost:8080" defaultDebugAddr = "localhost:8444" // proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol // header after accepting a connection. proxyProtoHeaderTimeout = 5 * time.Second // shutdownPreStopDelay is the time to wait after receiving a shutdown signal // before draining connections. This allows the load balancer to propagate // the endpoint removal. shutdownPreStopDelay = 5 * time.Second // shutdownDrainTimeout is the maximum time to wait for in-flight HTTP // requests to complete during graceful shutdown. shutdownDrainTimeout = 30 * time.Second // shutdownServiceTimeout is the maximum time to wait for auxiliary // services (health probe, debug endpoint, ACME) to shut down. shutdownServiceTimeout = 5 * time.Second // httpReadHeaderTimeout limits how long the server waits to read // request headers after accepting a connection. Prevents slowloris. 
httpReadHeaderTimeout = 10 * time.Second // httpIdleTimeout limits how long an idle keep-alive connection // stays open before the server closes it. httpIdleTimeout = 120 * time.Second ) func (s *Server) dialManagement() (*grpc.ClientConn, error) { mgmtURL, err := url.Parse(s.ManagementAddress) if err != nil { return nil, fmt.Errorf("parse management address: %w", err) } creds := insecure.NewCredentials() // Assume management TLS is enabled for gRPC as well if using HTTPS for the API. if mgmtURL.Scheme == "https" { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { // Fall back to embedded CAs if no OS-provided ones are available. certPool = embeddedroots.Get() } creds = credentials.NewTLS(&tls.Config{ RootCAs: certPool, }) } s.Logger.WithFields(log.Fields{ "gRPC_address": mgmtURL.Host, "TLS_enabled": mgmtURL.Scheme == "https", }).Debug("starting management gRPC client") conn, err := grpc.NewClient(mgmtURL.Host, grpc.WithTransportCredentials(creds), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 20 * time.Second, Timeout: 10 * time.Second, PermitWithoutStream: true, }), proxygrpc.WithProxyToken(s.ProxyToken), ) if err != nil { return nil, fmt.Errorf("create management connection: %w", err) } return conn, nil } func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { tlsConfig := &tls.Config{} if !s.GenerateACMECertificates { s.Logger.Debug("ACME certificates disabled, using static certificates with file watching") certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile) keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile) certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger) if err != nil { return nil, fmt.Errorf("initialize certificate watcher: %w", err) } go certWatcher.Watch(ctx) tlsConfig.GetCertificate = certWatcher.GetCertificate return tlsConfig, nil } if s.ACMEChallengeType == "" { s.ACMEChallengeType = "tls-alpn-01" } s.Logger.WithFields(log.Fields{ 
"acme_server": s.ACMEDirectory, "challenge_type": s.ACMEChallengeType, }).Debug("ACME certificates enabled, configuring certificate manager") var err error s.acme, err = acme.NewManager(acme.ManagerConfig{ CertDir: s.CertificateDirectory, ACMEURL: s.ACMEDirectory, EABKID: s.ACMEEABKID, EABHMACKey: s.ACMEEABHMACKey, LockMethod: s.CertLockMethod, WildcardDir: s.WildcardCertDir, }, s, s.Logger, s.meter) if err != nil { return nil, fmt.Errorf("create ACME manager: %w", err) } go s.acme.WatchWildcards(ctx) if s.ACMEChallengeType == "http-01" { s.http = &http.Server{ Addr: s.ACMEChallengeAddress, Handler: s.acme.HTTPHandler(nil), ErrorLog: newHTTPServerLogger(s.Logger, logtagValueACME), } go func() { if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed") } }() } tlsConfig = s.acme.TLSConfig() // autocert.Manager.TLSConfig() wires its own GetCertificate, which // bypasses our override that checks wildcards first. tlsConfig.GetCertificate = s.acme.GetCertificate // ServerName needs to be set to allow for ACME to work correctly // when using CNAME URLs to access the proxy. tlsConfig.ServerName = s.ProxyURL s.Logger.WithFields(log.Fields{ "ServerName": s.ProxyURL, "challenge_type": s.ACMEChallengeType, }).Debug("ACME certificate manager configured") return tlsConfig, nil } // gracefulShutdown performs a zero-downtime shutdown sequence. It marks the // readiness probe as failing, waits for load balancer propagation, drains // in-flight connections, and then stops all background services. func (s *Server) gracefulShutdown() { s.Logger.Info("shutdown signal received, starting graceful shutdown") // Step 1: Fail readiness probe so load balancers stop routing new traffic. if s.healthChecker != nil { s.healthChecker.SetShuttingDown() } // Step 2: When running behind a load balancer, wait for endpoint removal // to propagate before draining connections. 
if k8s.InCluster() { s.Logger.Infof("waiting %s for load balancer propagation", shutdownPreStopDelay) time.Sleep(shutdownPreStopDelay) } // Step 3: Stop accepting new connections and drain in-flight requests. drainCtx, drainCancel := context.WithTimeout(context.Background(), shutdownDrainTimeout) defer drainCancel() s.Logger.Info("draining in-flight connections") if err := s.https.Shutdown(drainCtx); err != nil { s.Logger.Warnf("https server drain: %v", err) } // Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle. if n := s.hijackTracker.CloseAll(); n > 0 { s.Logger.Infof("closed %d hijacked connection(s)", n) } // Drain all router relay connections (main + per-port) in parallel. s.drainAllRouters(shutdownDrainTimeout) // Step 5: Stop all remaining background services. s.shutdownServices() s.Logger.Info("graceful shutdown complete") } // shutdownServices stops all background services concurrently and waits for // them to finish. // drainAllRouters drains active relay connections on the main router and // all per-port routers in parallel, up to the given timeout. 
func (s *Server) drainAllRouters(timeout time.Duration) {
	var wg sync.WaitGroup

	// drain starts one goroutine per router; each waits up to timeout.
	drain := func(name string, router *nbtcp.Router) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if ok := router.Drain(timeout); !ok {
				s.Logger.Warnf("timed out draining %s relay connections", name)
			}
		}()
	}

	if s.mainRouter != nil {
		drain("main router", s.mainRouter)
	}

	// The read lock is only held while the drain goroutines are started,
	// not while they run.
	s.portMu.RLock()
	for port, pr := range s.portRouters {
		drain(fmt.Sprintf("port %d", port), pr.router)
	}
	s.portMu.RUnlock()

	wg.Wait()
}

// shutdownServices stops all background services and waits for them to
// finish. The health probe, debug endpoint, ACME HTTP server and NetBird
// clients are stopped concurrently; UDP relays and per-port routers are
// closed inline. The access logger and geolocation lookup are closed last,
// after everything else has stopped.
func (s *Server) shutdownServices() {
	var wg sync.WaitGroup

	// shutdownHTTP runs the given shutdown func in a goroutine, bounded by
	// shutdownServiceTimeout. Failures are only logged.
	shutdownHTTP := func(name string, shutdown func(context.Context) error) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), shutdownServiceTimeout)
			defer cancel()
			if err := shutdown(ctx); err != nil {
				s.Logger.Debugf("%s shutdown: %v", name, err)
			}
		}()
	}

	if s.healthServer != nil {
		shutdownHTTP("health probe", s.healthServer.Shutdown)
	}
	if s.debug != nil {
		shutdownHTTP("debug endpoint", s.debug.Shutdown)
	}
	if s.http != nil {
		shutdownHTTP("acme http", s.http.Shutdown)
	}

	if s.netbird != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// NetBird clients get the longer drain timeout rather than
			// shutdownServiceTimeout.
			ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
			defer cancel()
			if err := s.netbird.StopAll(ctx); err != nil {
				s.Logger.Warnf("stop netbird clients: %v", err)
			}
		}()
	}

	// Close all UDP relays and wait for their goroutines to exit.
	s.udpMu.Lock()
	for id, relay := range s.udpRelays {
		relay.Close()
		delete(s.udpRelays, id)
	}
	s.udpMu.Unlock()
	s.udpRelayWg.Wait()

	// Close all per-port routers.
	s.portMu.Lock()
	for port, pr := range s.portRouters {
		pr.cancel()
		if err := pr.listener.Close(); err != nil {
			s.Logger.Debugf("close listener on port %d: %v", port, err)
		}
		delete(s.portRouters, port)
	}
	maps.Clear(s.svcPorts)
	maps.Clear(s.lastMappings)
	s.portMu.Unlock()

	// Wait for per-port router serve goroutines to exit.
	s.portRouterWg.Wait()

	wg.Wait()

	if s.accessLog != nil {
		s.accessLog.Close()
	}

	if s.geoRaw != nil {
		if err := s.geoRaw.Close(); err != nil {
			s.Logger.Debugf("close geolocation: %v", err)
		}
	}
}

// resolveDialFunc returns a DialContextFunc that dials through the
// NetBird tunnel for the given account. An error is returned when no client
// is registered for the account.
func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFunc, error) {
	client, ok := s.netbird.GetClient(accountID)
	if !ok {
		return nil, fmt.Errorf("no client for account %s", accountID)
	}
	return client.DialContext, nil
}

// notifyError reports a resource error back to management so it can be
// surfaced to the user (e.g. port bind failure, dialer resolution error).
func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
	s.sendStatusUpdate(ctx, types.AccountID(mapping.GetAccountId()), types.ServiceID(mapping.GetId()), proto.ProxyStatus_PROXY_STATUS_ERROR, err)
}

// sendStatusUpdate sends a status update for a service to management.
// A non-nil err is attached as the request's error message. Delivery is
// best-effort: a failed send is only logged at debug level.
func (s *Server) sendStatusUpdate(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, st proto.ProxyStatus, err error) {
	req := &proto.SendStatusUpdateRequest{
		ServiceId: string(serviceID),
		AccountId: string(accountID),
		Status:    st,
	}
	if err != nil {
		msg := err.Error()
		req.ErrorMessage = &msg
	}
	if _, sendErr := s.mgmtClient.SendStatusUpdate(ctx, req); sendErr != nil {
		s.Logger.Debugf("failed to send status update for %s: %v", serviceID, sendErr)
	}
}

// routerForPort returns the router that handles the given listen port. If port
// is 0 or matches the main listener port, the main router is returned.
// Otherwise a new per-port router is created and started.
func (s *Server) routerForPort(ctx context.Context, port uint16) (*nbtcp.Router, error) {
	if port == 0 || port == s.mainPort {
		return s.mainRouter, nil
	}
	return s.getOrCreatePortRouter(ctx, port)
}

// routerForPortExisting returns the router for the given port without creating
// one. Returns the main router for port 0 / mainPort, or nil if no per-port
// router exists.
func (s *Server) routerForPortExisting(port uint16) *nbtcp.Router {
	if port == 0 || port == s.mainPort {
		return s.mainRouter
	}
	s.portMu.RLock()
	pr := s.portRouters[port]
	s.portMu.RUnlock()
	if pr != nil {
		return pr.router
	}
	return nil
}

// getOrCreatePortRouter returns an existing per-port router or creates one
// with a new TCP listener and starts serving. portMu is held for the whole
// call, including the listen, so concurrent callers for the same port get
// the same router.
func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp.Router, error) {
	s.portMu.Lock()
	defer s.portMu.Unlock()

	if pr, ok := s.portRouters[port]; ok {
		return pr.router, nil
	}

	listenAddr := fmt.Sprintf(":%d", port)
	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		return nil, fmt.Errorf("listen TCP on %s: %w", listenAddr, err)
	}
	if s.ProxyProtocol {
		ln = s.wrapProxyProtocol(ln)
	}

	router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc)
	router.SetObserver(s.meter)
	router.SetAccessLogger(s.accessLog)
	// The per-port context is derived from ctx and additionally canceled
	// when the port router is torn down.
	portCtx, cancel := context.WithCancel(ctx)

	s.portRouters[port] = &portRouter{
		router:   router,
		listener: ln,
		cancel:   cancel,
	}

	s.portRouterWg.Add(1)
	go func() {
		defer s.portRouterWg.Done()
		if err := router.Serve(portCtx, ln); err != nil {
			s.Logger.Debugf("port %d router stopped: %v", port, err)
		}
	}()

	s.Logger.Debugf("started per-port router on %s", listenAddr)
	return router, nil
}

// cleanupPortIfEmpty tears down a per-port router if it has no remaining
// routes or fallback. The main port is never cleaned up. The listener is
// closed first (under the lock); active relay connections are then drained
// outside the lock.
func (s *Server) cleanupPortIfEmpty(port uint16) {
	if port == 0 || port == s.mainPort {
		return
	}

	s.portMu.Lock()
	pr, ok := s.portRouters[port]
	if !ok || !pr.router.IsEmpty() {
		s.portMu.Unlock()
		return
	}

	// Cancel and close the listener while holding the lock so that
	// getOrCreatePortRouter sees the entry is gone before we drain.
	pr.cancel()
	if err := pr.listener.Close(); err != nil {
		s.Logger.Debugf("close listener on port %d: %v", port, err)
	}
	delete(s.portRouters, port)
	s.portMu.Unlock()

	// Drain active relay connections outside the lock.
	if ok := pr.router.Drain(nbtcp.DefaultDrainTimeout); !ok {
		s.Logger.Warnf("timed out draining relay connections on port %d", port)
	}
	s.Logger.Debugf("cleaned up empty per-port router on port %d", port)
}

// newManagementMappingWorker keeps a mapping-update stream to management open
// and processes the updates it delivers. Despite the name it does not return
// a worker: it blocks, reconnecting with exponential backoff, until ctx is
// canceled. The health checker (if set) is told whenever the management
// connection is established or lost.
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
	bo := &backoff.ExponentialBackOff{
		InitialInterval:     800 * time.Millisecond,
		RandomizationFactor: 1,
		Multiplier:          1.7,
		MaxInterval:         10 * time.Second,
		MaxElapsedTime:      0, // retry indefinitely until context is canceled
		Stop:                backoff.Stop,
		Clock:               backoff.SystemClock,
	}

	// initialSyncDone survives reconnects so the initial sync is only
	// reported once per worker.
	initialSyncDone := false

	// operation runs one stream session. It always returns a non-nil error
	// so that the retry loop reconnects after the stream ends.
	operation := func() error {
		s.Logger.Debug("connecting to management mapping stream")

		if s.healthChecker != nil {
			s.healthChecker.SetManagementConnected(false)
		}

		mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
			ProxyId:   s.ID,
			Version:   s.Version,
			StartedAt: timestamppb.New(s.startTime),
			Address:   s.ProxyURL,
			Capabilities: &proto.ProxyCapabilities{
				SupportsCustomPorts: &s.SupportsCustomPorts,
				RequireSubdomain:    &s.RequireSubdomain,
			},
		})
		if err != nil {
			return fmt.Errorf("create mapping stream: %w", err)
		}

		if s.healthChecker != nil {
			s.healthChecker.SetManagementConnected(true)
		}
		s.Logger.Debug("management mapping stream established")

		// Stream established — reset backoff so the next failure retries quickly.
		bo.Reset()

		streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone)

		if s.healthChecker != nil {
			s.healthChecker.SetManagementConnected(false)
		}

		if streamErr == nil {
			return fmt.Errorf("stream closed by server")
		}

		return fmt.Errorf("mapping stream: %w", streamErr)
	}

	notify := func(err error, next time.Duration) {
		s.Logger.Warnf("management connection failed, retrying in %s: %v", next.Truncate(time.Millisecond), err)
	}

	if err := backoff.RetryNotify(operation, backoff.WithContext(bo, ctx), notify); err != nil {
		s.Logger.WithError(err).Debug("management mapping worker exiting")
	}
}

// handleMappingStream receives mapping updates from mappingClient and applies
// them until the stream ends or ctx is canceled. It first waits for the main
// router to be ready. It returns nil when the server closes the stream
// (io.EOF), ctx.Err() on cancellation, and a wrapped error for any other
// receive failure. *initialSyncDone is set once the first message flagged as
// completing the initial sync has been processed.
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error {
	select {
	case <-s.routerReady:
	case <-ctx.Done():
		return ctx.Err()
	}

	for {
		// Check for context completion to gracefully shutdown.
		select {
		case <-ctx.Done():
			// Shutting down.
			return ctx.Err()
		default:
			msg, err := mappingClient.Recv()
			switch {
			case errors.Is(err, io.EOF):
				// Mapping connection gracefully terminated by server.
				return nil
			case err != nil:
				// Something has gone horribly wrong, return and hope the parent retries the connection.
				return fmt.Errorf("receive msg: %w", err)
			}
			s.Logger.Debug("Received mapping update, starting processing")
			s.processMappings(ctx, msg.GetMapping())
			s.Logger.Debug("Processing mapping update completed")

			if !*initialSyncDone && msg.GetInitialSyncComplete() {
				if s.healthChecker != nil {
					s.healthChecker.SetInitialSyncComplete()
				}
				*initialSyncDone = true
				s.Logger.Info("Initial mapping sync complete")
			}
		}
	}
}

// processMappings applies each mapping update in order, dispatching on its
// update type (created, modified, removed). A failing create or modify is
// logged and reported to management via notifyError; processing then
// continues with the next mapping.
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
	for _, mapping := range mappings {
		s.Logger.WithFields(log.Fields{
			"type":   mapping.GetType(),
			"domain": mapping.GetDomain(),
			"mode":   mapping.GetMode(),
			"port":   mapping.GetListenPort(),
			"id":     mapping.GetId(),
		}).Debug("Processing mapping update")
		switch mapping.GetType() {
		case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
			if err := s.addMapping(ctx, mapping); err != nil {
				s.Logger.WithFields(log.Fields{
					"service_id": mapping.GetId(),
					"domain":     mapping.GetDomain(),
					"error":      err,
				}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
				s.notifyError(ctx, mapping, err)
			}
		case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
			if err := s.modifyMapping(ctx, mapping); err != nil {
				s.Logger.WithFields(log.Fields{
					"service_id": mapping.GetId(),
					"domain":     mapping.GetDomain(),
					"error":      err,
				}).Error("failed to modify mapping")
				s.notifyError(ctx, mapping, err)
			}
		case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
			s.removeMapping(ctx, mapping)
		}
	}
}

// addMapping registers a service mapping and starts the appropriate relay or routes.
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { accountID := types.AccountID(mapping.GetAccountId()) svcID := types.ServiceID(mapping.GetId()) authToken := mapping.GetAuthToken() svcKey := s.serviceKeyForMapping(mapping) if err := s.netbird.AddPeer(ctx, accountID, svcKey, authToken, svcID); err != nil { return fmt.Errorf("create peer for service %s: %w", svcID, err) } if err := s.setupMappingRoutes(ctx, mapping); err != nil { s.cleanupMappingRoutes(mapping) if peerErr := s.netbird.RemovePeer(ctx, accountID, svcKey); peerErr != nil { s.Logger.WithError(peerErr).WithField("service_id", svcID).Warn("failed to remove peer after setup failure") } return err } s.storeMapping(mapping) return nil } // modifyMapping updates a service mapping in place without tearing down the // NetBird peer. It cleans up old routes using the previously stored mapping // state and re-applies them from the new mapping. func (s *Server) modifyMapping(ctx context.Context, mapping *proto.ProxyMapping) error { if old := s.loadMapping(types.ServiceID(mapping.GetId())); old != nil { s.cleanupMappingRoutes(old) if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { s.meter.L4ServiceRemoved(mode) } } else { s.cleanupMappingRoutes(mapping) } if err := s.setupMappingRoutes(ctx, mapping); err != nil { s.cleanupMappingRoutes(mapping) return err } s.storeMapping(mapping) return nil } // setupMappingRoutes configures the appropriate routes or relays for the given // service mapping based on its mode. The NetBird peer must already exist. 
func (s *Server) setupMappingRoutes(ctx context.Context, mapping *proto.ProxyMapping) error { switch types.ServiceMode(mapping.GetMode()) { case types.ServiceModeTCP: return s.setupTCPMapping(ctx, mapping) case types.ServiceModeUDP: return s.setupUDPMapping(ctx, mapping) case types.ServiceModeTLS: return s.setupTLSMapping(ctx, mapping) default: return s.setupHTTPMapping(ctx, mapping) } } // setupHTTPMapping configures HTTP reverse proxy, auth, and ACME routes. func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { d := domain.Domain(mapping.GetDomain()) accountID := types.AccountID(mapping.GetAccountId()) svcID := types.ServiceID(mapping.GetId()) if len(mapping.GetPath()) == 0 { return nil } var wildcardHit bool if s.acme != nil { wildcardHit = s.acme.AddDomain(d, accountID, svcID) } s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svcID, Domain: mapping.GetDomain(), }) if err := s.updateMapping(ctx, mapping); err != nil { return fmt.Errorf("update mapping for domain %q: %w", d, err) } if wildcardHit { if err := s.NotifyCertificateIssued(ctx, accountID, svcID, string(d)); err != nil { s.Logger.Warnf("notify certificate ready for domain %q: %v", d, err) } } return nil } // setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port. 
func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { svcID := types.ServiceID(mapping.GetId()) accountID := types.AccountID(mapping.GetAccountId()) port, err := netutil.ValidatePort(mapping.GetListenPort()) if err != nil { return fmt.Errorf("TCP service %s: %w", svcID, err) } targetAddr := s.l4TargetAddress(mapping) if targetAddr == "" { return fmt.Errorf("empty target address for TCP service %s", svcID) } if s.WireguardPort != 0 && port == s.WireguardPort { return fmt.Errorf("port %d conflicts with tunnel port", port) } router, err := s.routerForPort(ctx, port) if err != nil { return fmt.Errorf("router for TCP port %d: %w", port, err) } s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) router.SetGeo(s.geo) router.SetFallback(nbtcp.Route{ Type: nbtcp.RouteTCP, AccountID: accountID, ServiceID: svcID, Domain: mapping.GetDomain(), Protocol: accesslog.ProtocolTCP, Target: targetAddr, ProxyProtocol: s.l4ProxyProtocol(mapping), DialTimeout: s.l4DialTimeout(mapping), SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), Filter: parseRestrictions(mapping), }) s.portMu.Lock() s.svcPorts[svcID] = []uint16{port} s.portMu.Unlock() s.meter.L4ServiceAdded(types.ServiceModeTCP) s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) return nil } // setupUDPMapping starts a UDP relay on the listen port. 
func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMapping) error { svcID := types.ServiceID(mapping.GetId()) accountID := types.AccountID(mapping.GetAccountId()) port, err := netutil.ValidatePort(mapping.GetListenPort()) if err != nil { return fmt.Errorf("UDP service %s: %w", svcID, err) } targetAddr := s.l4TargetAddress(mapping) if targetAddr == "" { return fmt.Errorf("empty target address for UDP service %s", svcID) } s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil { return fmt.Errorf("UDP relay for service %s: %w", svcID, err) } s.meter.L4ServiceAdded(types.ServiceModeUDP) s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) return nil } // setupTLSMapping configures a TLS SNI-routed passthrough on the listen port. func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMapping) error { svcID := types.ServiceID(mapping.GetId()) accountID := types.AccountID(mapping.GetAccountId()) tlsPort, err := netutil.ValidatePort(mapping.GetListenPort()) if err != nil { return fmt.Errorf("TLS service %s: %w", svcID, err) } targetAddr := s.l4TargetAddress(mapping) if targetAddr == "" { return fmt.Errorf("empty target address for TLS service %s", svcID) } if s.WireguardPort != 0 && tlsPort == s.WireguardPort { return fmt.Errorf("port %d conflicts with tunnel port", tlsPort) } router, err := s.routerForPort(ctx, tlsPort) if err != nil { return fmt.Errorf("router for TLS port %d: %w", tlsPort, err) } s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) router.SetGeo(s.geo) router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ Type: nbtcp.RouteTCP, AccountID: accountID, ServiceID: svcID, Domain: mapping.GetDomain(), Protocol: accesslog.ProtocolTLS, Target: targetAddr, ProxyProtocol: s.l4ProxyProtocol(mapping), DialTimeout: s.l4DialTimeout(mapping), SessionIdleTimeout: 
s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), Filter: parseRestrictions(mapping), }) if tlsPort != s.mainPort { s.portMu.Lock() s.svcPorts[svcID] = []uint16{tlsPort} s.portMu.Unlock() } s.Logger.WithFields(log.Fields{ "domain": mapping.GetDomain(), "target": targetAddr, "port": tlsPort, "service": svcID, }).Info("TLS passthrough mapping added") s.meter.L4ServiceAdded(types.ServiceModeTLS) s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil) return nil } // serviceKeyForMapping returns the appropriate ServiceKey for a mapping. // TCP/UDP use an ID-based key; HTTP/TLS use a domain-based key. func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.ServiceKey { switch types.ServiceMode(mapping.GetMode()) { case types.ServiceModeTCP, types.ServiceModeUDP: return roundtrip.L4ServiceKey(types.ServiceID(mapping.GetId())) default: return roundtrip.DomainServiceKey(mapping.GetDomain()) } } // parseRestrictions converts a proto mapping's access restrictions into // a restrict.Filter. Returns nil if the mapping has no restrictions. func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter { r := mapping.GetAccessRestrictions() if r == nil { return nil } return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries()) } // warnIfGeoUnavailable logs a warning if the mapping has country restrictions // but the proxy has no geolocation database loaded. All requests to this // service will be denied at runtime (fail-close). 
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) { if r == nil { return } if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 { return } if s.geo != nil && s.geo.Available() { return } s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain) } // l4TargetAddress extracts and validates the target address from a mapping's // first path entry. Returns empty string if no paths exist or the address is // not a valid host:port. func (s *Server) l4TargetAddress(mapping *proto.ProxyMapping) string { paths := mapping.GetPath() if len(paths) == 0 { return "" } target := paths[0].GetTarget() if _, _, err := net.SplitHostPort(target); err != nil { s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "target": target, }).Warnf("invalid L4 target address: %v", err) return "" } return target } // l4ProxyProtocol returns whether the first target has PROXY protocol enabled. func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool { paths := mapping.GetPath() if len(paths) == 0 { return false } return paths[0].GetOptions().GetProxyProtocol() } // l4DialTimeout returns the dial timeout from the first target's options, // clamped to MaxDialTimeout. func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration { paths := mapping.GetPath() if len(paths) > 0 { if d := paths[0].GetOptions().GetRequestTimeout(); d != nil { return s.clampDialTimeout(d.AsDuration()) } } return s.clampDialTimeout(0) } // l4SessionIdleTimeout returns the configured session idle timeout from the // mapping options, or 0 to use the relay's default. func l4SessionIdleTimeout(mapping *proto.ProxyMapping) time.Duration { paths := mapping.GetPath() if len(paths) > 0 { if d := paths[0].GetOptions().GetSessionIdleTimeout(); d != nil { return d.AsDuration() } } return 0 } // addUDPRelay starts a UDP relay on the specified listen port. 
func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, targetAddress string, listenPort uint16) error { svcID := types.ServiceID(mapping.GetId()) accountID := types.AccountID(mapping.GetAccountId()) if s.WireguardPort != 0 && listenPort == s.WireguardPort { return fmt.Errorf("UDP port %d conflicts with tunnel port", listenPort) } // Close existing relay if present (idempotent re-add). s.removeUDPRelay(svcID) listenAddr := fmt.Sprintf(":%d", listenPort) listener, err := net.ListenPacket("udp", listenAddr) if err != nil { return fmt.Errorf("listen UDP on %s: %w", listenAddr, err) } dialFn, err := s.resolveDialFunc(accountID) if err != nil { if err := listener.Close(); err != nil { s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err) } return fmt.Errorf("resolve dialer for UDP: %w", err) } entry := s.Logger.WithFields(log.Fields{ "target": targetAddress, "listen_port": listenPort, "service_id": svcID, }) relay := udprelay.New(ctx, udprelay.RelayConfig{ Logger: entry, Listener: listener, Target: targetAddress, Domain: mapping.GetDomain(), AccountID: accountID, ServiceID: svcID, DialFunc: dialFn, DialTimeout: s.l4DialTimeout(mapping), SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)), AccessLog: s.accessLog, Filter: parseRestrictions(mapping), Geo: s.geo, }) relay.SetObserver(s.meter) s.udpMu.Lock() s.udpRelays[svcID] = relay s.udpMu.Unlock() s.udpRelayWg.Go(relay.Serve) entry.Info("UDP relay added") return nil } func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) error { // Very simple implementation here, we don't touch the existing peer // connection or any existing TLS configuration, we simply overwrite // the auth and proxy mappings. // Note: this does require the management server to always send a // full mapping rather than deltas during a modification. 
accountID := types.AccountID(mapping.GetAccountId()) svcID := types.ServiceID(mapping.GetId()) var schemes []auth.Scheme if mapping.GetAuth().GetPassword() { schemes = append(schemes, auth.NewPassword(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetPin() { schemes = append(schemes, auth.NewPin(s.mgmtClient, svcID, accountID)) } if mapping.GetAuth().GetOidc() { schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto)) } for _, ha := range mapping.GetAuth().GetHeaderAuths() { schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader())) } ipRestrictions := parseRestrictions(mapping) s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil { return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) } m := s.protoToMapping(ctx, mapping) s.proxy.AddMapping(m) s.meter.AddMapping(m) return nil } // removeMapping tears down routes/relays and the NetBird peer for a service. // Uses the stored mapping state when available to ensure all previously // configured routes are cleaned up. 
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) { accountID := types.AccountID(mapping.GetAccountId()) svcKey := s.serviceKeyForMapping(mapping) if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil { s.Logger.WithFields(log.Fields{ "account_id": accountID, "service_id": mapping.GetId(), "error": err, }).Error("failed to remove NetBird peer, continuing cleanup") } if old := s.deleteMapping(types.ServiceID(mapping.GetId())); old != nil { s.cleanupMappingRoutes(old) if mode := types.ServiceMode(old.GetMode()); mode.IsL4() { s.meter.L4ServiceRemoved(mode) } } else { s.cleanupMappingRoutes(mapping) } } // cleanupMappingRoutes removes HTTP/TLS/L4 routes and custom port state for a // service without touching the NetBird peer. This is used for both full // removal and in-place modification of mappings. func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { svcID := types.ServiceID(mapping.GetId()) host := mapping.GetDomain() // HTTP/TLS cleanup (only relevant when a domain is set). if host != "" { d := domain.Domain(host) if s.acme != nil { s.acme.RemoveDomain(d) } s.auth.RemoveDomain(host) if s.proxy.RemoveMapping(proxy.Mapping{Host: host}) { s.meter.RemoveMapping(proxy.Mapping{Host: host}) } // Close hijacked connections (WebSocket) for this domain. if n := s.hijackTracker.CloseByHost(host); n > 0 { s.Logger.Debugf("closed %d hijacked connection(s) for %s", n, host) } // Remove SNI route from the main router (covers both HTTP and main-port TLS). s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID) } // Extract and delete tracked custom-port entries atomically. s.portMu.Lock() entries := s.svcPorts[svcID] delete(s.svcPorts, svcID) s.portMu.Unlock() for _, entry := range entries { if router := s.routerForPortExisting(entry); router != nil { if host != "" { router.RemoveRoute(nbtcp.SNIHost(host), svcID) } else { router.RemoveFallback(svcID) } } s.cleanupPortIfEmpty(entry) } // UDP relay cleanup (idempotent). 
s.removeUDPRelay(svcID) } // removeUDPRelay stops and removes a UDP relay by service ID. func (s *Server) removeUDPRelay(svcID types.ServiceID) { s.udpMu.Lock() relay, ok := s.udpRelays[svcID] if ok { delete(s.udpRelays, svcID) } s.udpMu.Unlock() if ok { relay.Close() s.Logger.WithField("service_id", svcID).Info("UDP relay removed") } } func (s *Server) storeMapping(mapping *proto.ProxyMapping) { s.portMu.Lock() s.lastMappings[types.ServiceID(mapping.GetId())] = mapping s.portMu.Unlock() } func (s *Server) loadMapping(svcID types.ServiceID) *proto.ProxyMapping { s.portMu.RLock() m := s.lastMappings[svcID] s.portMu.RUnlock() return m } func (s *Server) deleteMapping(svcID types.ServiceID) *proto.ProxyMapping { s.portMu.Lock() m := s.lastMappings[svcID] delete(s.lastMappings, svcID) s.portMu.Unlock() return m } func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping) proxy.Mapping { paths := make(map[string]*proxy.PathTarget) for _, pathMapping := range mapping.GetPath() { targetURL, err := url.Parse(pathMapping.GetTarget()) if err != nil { s.Logger.WithFields(log.Fields{ "service_id": mapping.GetId(), "account_id": mapping.GetAccountId(), "domain": mapping.GetDomain(), "path": pathMapping.GetPath(), "target": pathMapping.GetTarget(), }).WithError(err).Error("failed to parse target URL for path, skipping") s.notifyError(ctx, mapping, fmt.Errorf("invalid target URL %q for path %q: %w", pathMapping.GetTarget(), pathMapping.GetPath(), err)) continue } pt := &proxy.PathTarget{URL: targetURL} if opts := pathMapping.GetOptions(); opts != nil { pt.SkipTLSVerify = opts.GetSkipTlsVerify() pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite()) pt.CustomHeaders = opts.GetCustomHeaders() if d := opts.GetRequestTimeout(); d != nil { pt.RequestTimeout = d.AsDuration() } } pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout) paths[pathMapping.GetPath()] = pt } m := proxy.Mapping{ ID: types.ServiceID(mapping.GetId()), AccountID: 
types.AccountID(mapping.GetAccountId()), Host: mapping.GetDomain(), Paths: paths, PassHostHeader: mapping.GetPassHostHeader(), RewriteRedirects: mapping.GetRewriteRedirects(), } for _, ha := range mapping.GetAuth().GetHeaderAuths() { m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader()) } return m } func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode { switch mode { case proto.PathRewriteMode_PATH_REWRITE_PRESERVE: return proxy.PathRewritePreserve default: return proxy.PathRewriteDefault } } // debugEndpointAddr returns the address for the debug endpoint. // If addr is empty, it defaults to localhost:8444 for security. func debugEndpointAddr(addr string) string { if addr == "" { return defaultDebugAddr } return addr }