mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
1607 lines
52 KiB
Go
1607 lines
52 KiB
Go
// Package proxy runs a NetBird proxy server.
|
|
// It attempts to do everything it needs to do within the context
|
|
// of a single request to the server to try to reduce the amount
|
|
// of concurrency coordination that is required. However, it does
// run several background goroutines, notably one for handling
// updates from the management server and, if configured, a separate
// HTTP server to handle ACME HTTP-01 challenges.
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/netip"
|
|
"net/url"
|
|
"path/filepath"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/pires/go-proxyproto"
|
|
prometheus2 "github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
log "github.com/sirupsen/logrus"
|
|
"go.opentelemetry.io/otel/exporters/prometheus"
|
|
"go.opentelemetry.io/otel/sdk/metric"
|
|
"golang.org/x/exp/maps"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/keepalive"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
|
"github.com/netbirdio/netbird/proxy/internal/acme"
|
|
"github.com/netbirdio/netbird/proxy/internal/auth"
|
|
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
|
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
|
"github.com/netbirdio/netbird/proxy/internal/debug"
|
|
"github.com/netbirdio/netbird/proxy/internal/geolocation"
|
|
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
|
"github.com/netbirdio/netbird/proxy/internal/health"
|
|
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
|
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
|
|
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
|
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
|
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
|
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
|
udprelay "github.com/netbirdio/netbird/proxy/internal/udp"
|
|
"github.com/netbirdio/netbird/proxy/web"
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
"github.com/netbirdio/netbird/shared/management/proto"
|
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
|
)
|
|
|
|
// portRouter bundles a per-port Router with the TCP listener it serves and
// the cancel func for the context handed to the router's Serve loop.
// Entries are created by getOrCreatePortRouter and stored in
// Server.portRouters.
type portRouter struct {
	// router handles connections accepted on this port.
	router *nbtcp.Router
	// listener is the port's TCP listener, wrapped for PROXY protocol when
	// Server.ProxyProtocol is set.
	listener net.Listener
	// cancel cancels the context passed to router.Serve.
	cancel context.CancelFunc
}
|
|
|
|
// Server is the NetBird reverse proxy. Exported fields are configuration
// set by the caller before ListenAndServe; unexported fields hold runtime
// state created by ListenAndServe and its helpers.
type Server struct {
	// mgmtClient is the gRPC client for the management ProxyService.
	mgmtClient proto.ProxyServiceClient
	// proxy is the reverse proxy at the center of the HTTP handler chain.
	proxy *proxy.ReverseProxy
	// netbird provides the NetBird clients used to reach backends.
	netbird *roundtrip.NetBird
	// acme is the ACME certificate manager; nil when GenerateACMECertificates is false.
	acme *acme.Manager
	// auth is the authentication middleware wrapping the reverse proxy.
	auth *auth.Middleware
	// http serves ACME HTTP-01 challenges; nil unless that challenge type is used.
	http *http.Server
	// https serves HTTP traffic handed over by the SNI router.
	https *http.Server
	// debug is the optional debug endpoint server; nil when disabled.
	debug *http.Server
	// healthServer serves health probes and Prometheus metrics.
	healthServer *health.Server
	// healthChecker tracks readiness/management connectivity for the probes.
	healthChecker *health.Checker
	meter         *proxymetrics.Metrics
	accessLog     *accesslog.Logger
	// mainRouter peeks at the ClientHello on the main listener and routes to
	// either the HTTP handler or a TCP relay.
	mainRouter *nbtcp.Router
	// mainPort is the TCP port of the main listener.
	mainPort uint16
	// udpMu guards udpRelays.
	udpMu     sync.Mutex
	udpRelays map[types.ServiceID]*udprelay.Relay
	// udpRelayWg is waited on during shutdown for relay goroutines to exit.
	udpRelayWg sync.WaitGroup
	// portMu guards portRouters; shutdownServices also clears svcPorts and
	// lastMappings while holding it.
	portMu      sync.RWMutex
	portRouters map[uint16]*portRouter
	// svcPorts presumably records the listen ports used by each service — verify.
	svcPorts map[types.ServiceID][]uint16
	// lastMappings presumably keeps the most recent mapping per service — verify.
	lastMappings map[types.ServiceID]*proto.ProxyMapping
	// portRouterWg tracks the Serve goroutines of per-port routers.
	portRouterWg sync.WaitGroup

	// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
	// so they can be closed during graceful shutdown, since http.Server.Shutdown
	// does not handle them.
	hijackTracker conntrack.HijackTracker
	// geo resolves IP addresses to country/city for access restrictions and access logs.
	// It is left nil (not a typed nil) when no geolocation data is loaded.
	geo restrict.GeoResolver
	// geoRaw is the concrete lookup behind geo, kept so it can be closed.
	geoRaw *geolocation.Lookup

	// routerReady is closed once mainRouter is fully initialized.
	// The mapping worker waits on this before processing updates.
	routerReady chan struct{}

	// Mostly used for debugging on management.
	startTime time.Time

	// ID identifies this proxy instance to management; generated from the
	// start time when empty.
	ID string
	// Logger is used for all logging; defaults to logrus' standard logger.
	Logger *log.Logger
	// Version is reported to management; defaults to "dev".
	Version string
	// ProxyURL is this proxy's public address. It is sent to management and
	// used as the TLS ServerName when ACME is enabled.
	ProxyURL string
	// ManagementAddress is the management server URL; an "https" scheme
	// enables TLS for the gRPC connection.
	ManagementAddress string
	// CertificateDirectory holds static certificates (when ACME is disabled)
	// or the ACME certificate cache.
	CertificateDirectory string
	// CertificateFile is the static certificate file name inside CertificateDirectory.
	CertificateFile string
	// CertificateKeyFile is the static key file name inside CertificateDirectory.
	CertificateKeyFile string
	// GenerateACMECertificates switches from static certificates to ACME.
	GenerateACMECertificates bool
	// ACMEChallengeAddress is the listen address of the HTTP-01 challenge server.
	ACMEChallengeAddress string
	// ACMEDirectory is the ACME directory URL of the certificate authority.
	ACMEDirectory string
	// ACMEEABKID is the External Account Binding Key ID for CAs that require EAB (e.g., ZeroSSL).
	ACMEEABKID string
	// ACMEEABHMACKey is the External Account Binding HMAC key (base64 URL-encoded) for CAs that require EAB.
	ACMEEABHMACKey string
	// ACMEChallengeType specifies the ACME challenge type: "http-01" or "tls-alpn-01".
	// Defaults to "tls-alpn-01" if not specified.
	ACMEChallengeType string
	// CertLockMethod controls how ACME certificate locks are coordinated
	// across replicas. Default: CertLockAuto (detect environment).
	CertLockMethod acme.CertLockMethod
	// WildcardCertDir is an optional directory containing wildcard certificate
	// pairs (<name>.crt / <name>.key). Wildcard patterns are extracted from
	// the certificates' SAN lists. Matching domains use these static certs
	// instead of ACME.
	WildcardCertDir string

	// DebugEndpointEnabled enables the debug HTTP endpoint.
	DebugEndpointEnabled bool
	// DebugEndpointAddress is the address for the debug HTTP endpoint. When
	// empty, debugEndpointAddr chooses a default (presumably
	// defaultDebugAddr, "localhost:8444" — verify).
	DebugEndpointAddress string
	// HealthAddress is the address for the health probe endpoint
	// (defaultHealthAddr when empty).
	HealthAddress string
	// ProxyToken is the access token for authenticating with the management server.
	ProxyToken string
	// ForwardedProto overrides the X-Forwarded-Proto value sent to backends.
	// Valid values: "auto" (detect from TLS), "http", "https".
	ForwardedProto string
	// TrustedProxies is a list of IP prefixes for trusted upstream proxies.
	// When set, forwarding headers from these sources are preserved and
	// appended to instead of being stripped. With ProxyProtocol, only these
	// sources may send PROXY headers.
	TrustedProxies []netip.Prefix
	// WireguardPort is the port for the NetBird tunnel interface. Use 0
	// for a random OS-assigned port. A fixed port only works with
	// single-account deployments; multiple accounts will fail to bind
	// the same port.
	WireguardPort uint16
	// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
	// When enabled, the real client IP is extracted from the PROXY header
	// sent by upstream L4 proxies that support PROXY protocol.
	ProxyProtocol bool
	// PreSharedKey used for tunnel between proxy and peers (set globally not per account)
	PreSharedKey string
	// SupportsCustomPorts indicates whether the proxy can bind arbitrary
	// ports for TCP/UDP/TLS services.
	SupportsCustomPorts bool
	// RequireSubdomain indicates whether a subdomain label is required
	// in front of this proxy's cluster domain. When true, accounts cannot
	// create services on the bare cluster domain.
	RequireSubdomain bool
	// MaxDialTimeout caps the per-service backend dial timeout.
	// When the API sends a timeout, it is clamped to this value.
	// When the API sends no timeout, this value is used as the default.
	// Zero means no cap (the proxy honors whatever management sends).
	MaxDialTimeout time.Duration
	// GeoDataDir is the directory containing GeoLite2 MMDB files for
	// country-based access restrictions. Empty disables geo lookups.
	GeoDataDir string
	// MaxSessionIdleTimeout caps the per-service session idle timeout.
	// Zero means no cap (the proxy honors whatever management sends).
	// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
	MaxSessionIdleTimeout time.Duration
}
|
|
|
|
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
|
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
|
|
if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
|
|
return s.MaxSessionIdleTimeout
|
|
}
|
|
return d
|
|
}
|
|
|
|
// clampDialTimeout returns d capped to MaxDialTimeout when configured.
|
|
// If d is zero, MaxDialTimeout is used as the default.
|
|
func (s *Server) clampDialTimeout(d time.Duration) time.Duration {
|
|
if s.MaxDialTimeout <= 0 {
|
|
return d
|
|
}
|
|
if d <= 0 || d > s.MaxDialTimeout {
|
|
return s.MaxDialTimeout
|
|
}
|
|
return d
|
|
}
|
|
|
|
// NotifyStatus sends a status update to management about tunnel connectivity.
|
|
func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
|
|
status := proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED
|
|
if connected {
|
|
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
|
|
}
|
|
|
|
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
|
|
ServiceId: string(serviceID),
|
|
AccountId: string(accountID),
|
|
Status: status,
|
|
CertificateIssued: false,
|
|
})
|
|
return err
|
|
}
|
|
|
|
// NotifyCertificateIssued sends a notification to management that a certificate was issued
|
|
func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error {
|
|
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
|
|
ServiceId: string(serviceID),
|
|
AccountId: string(accountID),
|
|
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
|
|
CertificateIssued: true,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|
s.initDefaults()
|
|
s.routerReady = make(chan struct{})
|
|
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
|
|
s.portRouters = make(map[uint16]*portRouter)
|
|
s.svcPorts = make(map[types.ServiceID][]uint16)
|
|
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
|
|
|
|
exporter, err := prometheus.New()
|
|
if err != nil {
|
|
return fmt.Errorf("create prometheus exporter: %w", err)
|
|
}
|
|
|
|
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
|
pkg := reflect.TypeOf(Server{}).PkgPath()
|
|
meter := provider.Meter(pkg)
|
|
|
|
s.meter, err = proxymetrics.New(ctx, meter)
|
|
if err != nil {
|
|
return fmt.Errorf("create metrics: %w", err)
|
|
}
|
|
|
|
mgmtConn, err := s.dialManagement()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err := mgmtConn.Close(); err != nil {
|
|
s.Logger.Debugf("management connection close: %v", err)
|
|
}
|
|
}()
|
|
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
|
runCtx, runCancel := context.WithCancel(ctx)
|
|
defer runCancel()
|
|
|
|
// Initialize the netbird client, this is required to build peer connections
|
|
// to proxy over.
|
|
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
|
MgmtAddr: s.ManagementAddress,
|
|
WGPort: s.WireguardPort,
|
|
PreSharedKey: s.PreSharedKey,
|
|
}, s.Logger, s, s.mgmtClient)
|
|
|
|
// Create health checker before the mapping worker so it can track
|
|
// management connectivity from the first stream connection.
|
|
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
|
|
|
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
|
|
|
|
tlsConfig, err := s.configureTLS(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
|
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
|
|
|
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
|
|
if err != nil {
|
|
return fmt.Errorf("initialize geolocation: %w", err)
|
|
}
|
|
s.geoRaw = geoLookup
|
|
if geoLookup != nil {
|
|
s.geo = geoLookup
|
|
}
|
|
|
|
var startupOK bool
|
|
defer func() {
|
|
if startupOK {
|
|
return
|
|
}
|
|
if s.geoRaw != nil {
|
|
if err := s.geoRaw.Close(); err != nil {
|
|
s.Logger.Debugf("close geolocation on startup failure: %v", err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Configure the authentication middleware with session validator for OIDC group checks.
|
|
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
|
|
|
|
// Configure Access logs to management server.
|
|
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
|
|
|
s.startDebugEndpoint()
|
|
|
|
if err := s.startHealthServer(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Build the handler chain from inside out.
|
|
handler := http.Handler(s.proxy)
|
|
handler = s.auth.Protect(handler)
|
|
handler = web.AssetHandler(handler)
|
|
handler = s.accessLog.Middleware(handler)
|
|
handler = s.meter.Middleware(handler)
|
|
handler = s.hijackTracker.Middleware(handler)
|
|
|
|
// Start a raw TCP listener; the SNI router peeks at ClientHello
|
|
// and routes to either the HTTP handler or a TCP relay.
|
|
lc := net.ListenConfig{}
|
|
ln, err := lc.Listen(ctx, "tcp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen on %s: %w", addr, err)
|
|
}
|
|
if s.ProxyProtocol {
|
|
ln = s.wrapProxyProtocol(ln)
|
|
}
|
|
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
|
|
|
|
// Set up the SNI router for TCP/HTTP multiplexing on the main port.
|
|
s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr())
|
|
s.mainRouter.SetObserver(s.meter)
|
|
s.mainRouter.SetAccessLogger(s.accessLog)
|
|
close(s.routerReady)
|
|
|
|
// The HTTP server uses the chanListener fed by the SNI router.
|
|
s.https = &http.Server{
|
|
Addr: addr,
|
|
Handler: handler,
|
|
TLSConfig: tlsConfig,
|
|
ReadHeaderTimeout: httpReadHeaderTimeout,
|
|
IdleTimeout: httpIdleTimeout,
|
|
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
|
}
|
|
|
|
startupOK = true
|
|
|
|
httpsErr := make(chan error, 1)
|
|
go func() {
|
|
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
|
|
httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "")
|
|
}()
|
|
|
|
routerErr := make(chan error, 1)
|
|
go func() {
|
|
s.Logger.Debugf("starting SNI router on %s", addr)
|
|
routerErr <- s.mainRouter.Serve(runCtx, ln)
|
|
}()
|
|
|
|
select {
|
|
case err := <-httpsErr:
|
|
s.shutdownServices()
|
|
if !errors.Is(err, http.ErrServerClosed) {
|
|
return fmt.Errorf("https server: %w", err)
|
|
}
|
|
return nil
|
|
case err := <-routerErr:
|
|
s.shutdownServices()
|
|
if err != nil {
|
|
return fmt.Errorf("SNI router: %w", err)
|
|
}
|
|
return nil
|
|
case <-ctx.Done():
|
|
s.gracefulShutdown()
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// initDefaults sets fallback values for optional Server fields.
|
|
func (s *Server) initDefaults() {
|
|
s.startTime = time.Now()
|
|
|
|
// If no ID is set then one can be generated.
|
|
if s.ID == "" {
|
|
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
|
}
|
|
// Fallback version option in case it is not set.
|
|
if s.Version == "" {
|
|
s.Version = "dev"
|
|
}
|
|
|
|
// If no logger is specified fallback to the standard logger.
|
|
if s.Logger == nil {
|
|
s.Logger = log.StandardLogger()
|
|
}
|
|
}
|
|
|
|
// startDebugEndpoint launches the debug HTTP server if enabled.
|
|
func (s *Server) startDebugEndpoint() {
|
|
if !s.DebugEndpointEnabled {
|
|
return
|
|
}
|
|
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
|
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
|
if s.acme != nil {
|
|
debugHandler.SetCertStatus(s.acme)
|
|
}
|
|
s.debug = &http.Server{
|
|
Addr: debugAddr,
|
|
Handler: debugHandler,
|
|
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
|
}
|
|
go func() {
|
|
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
|
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
s.Logger.Errorf("debug endpoint error: %v", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// startHealthServer launches the health probe and metrics server.
|
|
func (s *Server) startHealthServer() error {
|
|
healthAddr := s.HealthAddress
|
|
if healthAddr == "" {
|
|
healthAddr = defaultHealthAddr
|
|
}
|
|
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
|
|
healthListener, err := net.Listen("tcp", healthAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
|
}
|
|
go func() {
|
|
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
s.Logger.Errorf("health probe server: %v", err)
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
// wrapProxyProtocol wraps a listener with PROXY protocol support.
|
|
// When TrustedProxies is configured, only those sources may send PROXY headers;
|
|
// connections from untrusted sources have any PROXY header ignored.
|
|
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
|
|
ppListener := &proxyproto.Listener{
|
|
Listener: ln,
|
|
ReadHeaderTimeout: proxyProtoHeaderTimeout,
|
|
}
|
|
if len(s.TrustedProxies) > 0 {
|
|
ppListener.ConnPolicy = s.proxyProtocolPolicy
|
|
} else {
|
|
s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
|
|
}
|
|
s.Logger.Info("PROXY protocol enabled on listener")
|
|
return ppListener
|
|
}
|
|
|
|
// proxyProtocolPolicy returns whether to require, skip, or reject the PROXY
|
|
// header based on whether the connection source is in TrustedProxies.
|
|
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
|
|
// No logging on reject to prevent abuse
|
|
tcpAddr, ok := opts.Upstream.(*net.TCPAddr)
|
|
if !ok {
|
|
return proxyproto.REJECT, nil
|
|
}
|
|
addr, ok := netip.AddrFromSlice(tcpAddr.IP)
|
|
if !ok {
|
|
return proxyproto.REJECT, nil
|
|
}
|
|
addr = addr.Unmap()
|
|
|
|
// called per accept
|
|
for _, prefix := range s.TrustedProxies {
|
|
if prefix.Contains(addr) {
|
|
return proxyproto.REQUIRE, nil
|
|
}
|
|
}
|
|
return proxyproto.IGNORE, nil
|
|
}
|
|
|
|
// Default listen addresses for the auxiliary HTTP endpoints.
const (
	defaultHealthAddr = "localhost:8080"
	defaultDebugAddr  = "localhost:8444"
)

// Connection-handling limits.
const (
	// proxyProtoHeaderTimeout bounds how long a freshly accepted connection
	// may take to deliver its PROXY protocol header.
	proxyProtoHeaderTimeout = 5 * time.Second

	// httpReadHeaderTimeout bounds how long the server waits for request
	// headers after accepting a connection, guarding against slowloris.
	httpReadHeaderTimeout = 10 * time.Second

	// httpIdleTimeout is how long an idle keep-alive connection may stay
	// open before the server closes it.
	httpIdleTimeout = 120 * time.Second
)

// Shutdown timing.
const (
	// shutdownPreStopDelay is the pause after a shutdown signal, before
	// draining starts, that lets a load balancer propagate the endpoint
	// removal.
	shutdownPreStopDelay = 5 * time.Second

	// shutdownDrainTimeout caps how long graceful shutdown waits for
	// in-flight HTTP requests to complete.
	shutdownDrainTimeout = 30 * time.Second

	// shutdownServiceTimeout caps how long each auxiliary service (health
	// probe, debug endpoint, ACME) gets to shut down.
	shutdownServiceTimeout = 5 * time.Second
)
|
|
|
|
func (s *Server) dialManagement() (*grpc.ClientConn, error) {
|
|
mgmtURL, err := url.Parse(s.ManagementAddress)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse management address: %w", err)
|
|
}
|
|
creds := insecure.NewCredentials()
|
|
// Assume management TLS is enabled for gRPC as well if using HTTPS for the API.
|
|
if mgmtURL.Scheme == "https" {
|
|
certPool, err := x509.SystemCertPool()
|
|
if err != nil || certPool == nil {
|
|
// Fall back to embedded CAs if no OS-provided ones are available.
|
|
certPool = embeddedroots.Get()
|
|
}
|
|
creds = credentials.NewTLS(&tls.Config{
|
|
RootCAs: certPool,
|
|
})
|
|
}
|
|
s.Logger.WithFields(log.Fields{
|
|
"gRPC_address": mgmtURL.Host,
|
|
"TLS_enabled": mgmtURL.Scheme == "https",
|
|
}).Debug("starting management gRPC client")
|
|
conn, err := grpc.NewClient(mgmtURL.Host,
|
|
grpc.WithTransportCredentials(creds),
|
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
Time: 20 * time.Second,
|
|
Timeout: 10 * time.Second,
|
|
PermitWithoutStream: true,
|
|
}),
|
|
proxygrpc.WithProxyToken(s.ProxyToken),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create management connection: %w", err)
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
|
|
tlsConfig := &tls.Config{}
|
|
if !s.GenerateACMECertificates {
|
|
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
|
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
|
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
|
|
|
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
|
|
}
|
|
go certWatcher.Watch(ctx)
|
|
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
if s.ACMEChallengeType == "" {
|
|
s.ACMEChallengeType = "tls-alpn-01"
|
|
}
|
|
s.Logger.WithFields(log.Fields{
|
|
"acme_server": s.ACMEDirectory,
|
|
"challenge_type": s.ACMEChallengeType,
|
|
}).Debug("ACME certificates enabled, configuring certificate manager")
|
|
var err error
|
|
s.acme, err = acme.NewManager(acme.ManagerConfig{
|
|
CertDir: s.CertificateDirectory,
|
|
ACMEURL: s.ACMEDirectory,
|
|
EABKID: s.ACMEEABKID,
|
|
EABHMACKey: s.ACMEEABHMACKey,
|
|
LockMethod: s.CertLockMethod,
|
|
WildcardDir: s.WildcardCertDir,
|
|
}, s, s.Logger, s.meter)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create ACME manager: %w", err)
|
|
}
|
|
|
|
go s.acme.WatchWildcards(ctx)
|
|
|
|
if s.ACMEChallengeType == "http-01" {
|
|
s.http = &http.Server{
|
|
Addr: s.ACMEChallengeAddress,
|
|
Handler: s.acme.HTTPHandler(nil),
|
|
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueACME),
|
|
}
|
|
go func() {
|
|
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
|
}
|
|
}()
|
|
}
|
|
tlsConfig = s.acme.TLSConfig()
|
|
|
|
// autocert.Manager.TLSConfig() wires its own GetCertificate, which
|
|
// bypasses our override that checks wildcards first.
|
|
tlsConfig.GetCertificate = s.acme.GetCertificate
|
|
|
|
// ServerName needs to be set to allow for ACME to work correctly
|
|
// when using CNAME URLs to access the proxy.
|
|
tlsConfig.ServerName = s.ProxyURL
|
|
|
|
s.Logger.WithFields(log.Fields{
|
|
"ServerName": s.ProxyURL,
|
|
"challenge_type": s.ACMEChallengeType,
|
|
}).Debug("ACME certificate manager configured")
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
|
|
// readiness probe as failing, waits for load balancer propagation, drains
|
|
// in-flight connections, and then stops all background services.
|
|
func (s *Server) gracefulShutdown() {
|
|
s.Logger.Info("shutdown signal received, starting graceful shutdown")
|
|
|
|
// Step 1: Fail readiness probe so load balancers stop routing new traffic.
|
|
if s.healthChecker != nil {
|
|
s.healthChecker.SetShuttingDown()
|
|
}
|
|
|
|
// Step 2: When running behind a load balancer, wait for endpoint removal
|
|
// to propagate before draining connections.
|
|
if k8s.InCluster() {
|
|
s.Logger.Infof("waiting %s for load balancer propagation", shutdownPreStopDelay)
|
|
time.Sleep(shutdownPreStopDelay)
|
|
}
|
|
|
|
// Step 3: Stop accepting new connections and drain in-flight requests.
|
|
drainCtx, drainCancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
|
|
defer drainCancel()
|
|
|
|
s.Logger.Info("draining in-flight connections")
|
|
if err := s.https.Shutdown(drainCtx); err != nil {
|
|
s.Logger.Warnf("https server drain: %v", err)
|
|
}
|
|
|
|
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
|
if n := s.hijackTracker.CloseAll(); n > 0 {
|
|
s.Logger.Infof("closed %d hijacked connection(s)", n)
|
|
}
|
|
|
|
// Drain all router relay connections (main + per-port) in parallel.
|
|
s.drainAllRouters(shutdownDrainTimeout)
|
|
|
|
// Step 5: Stop all remaining background services.
|
|
s.shutdownServices()
|
|
s.Logger.Info("graceful shutdown complete")
|
|
}
|
|
|
|
// shutdownServices stops all background services concurrently and waits for
|
|
// them to finish.
|
|
// drainAllRouters drains active relay connections on the main router and
|
|
// all per-port routers in parallel, up to the given timeout.
|
|
func (s *Server) drainAllRouters(timeout time.Duration) {
|
|
var wg sync.WaitGroup
|
|
|
|
drain := func(name string, router *nbtcp.Router) {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if ok := router.Drain(timeout); !ok {
|
|
s.Logger.Warnf("timed out draining %s relay connections", name)
|
|
}
|
|
}()
|
|
}
|
|
|
|
if s.mainRouter != nil {
|
|
drain("main router", s.mainRouter)
|
|
}
|
|
|
|
s.portMu.RLock()
|
|
for port, pr := range s.portRouters {
|
|
drain(fmt.Sprintf("port %d", port), pr.router)
|
|
}
|
|
s.portMu.RUnlock()
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// shutdownServices stops the proxy's background services and releases the
// resources they hold, returning once everything has finished.
//
// The health probe, debug and ACME HTTP-01 servers are shut down in
// parallel goroutines, each bounded by shutdownServiceTimeout, alongside
// NetBird client shutdown (bounded by shutdownDrainTimeout). While those
// run, UDP relays and per-port TCP routers are closed and their goroutines
// awaited. The access logger and geolocation database are closed last.
func (s *Server) shutdownServices() {
	var wg sync.WaitGroup

	// shutdownHTTP runs shutdown in its own goroutine with a bounded context;
	// failures are only logged at debug level.
	shutdownHTTP := func(name string, shutdown func(context.Context) error) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), shutdownServiceTimeout)
			defer cancel()
			if err := shutdown(ctx); err != nil {
				s.Logger.Debugf("%s shutdown: %v", name, err)
			}
		}()
	}

	if s.healthServer != nil {
		shutdownHTTP("health probe", s.healthServer.Shutdown)
	}
	if s.debug != nil {
		shutdownHTTP("debug endpoint", s.debug.Shutdown)
	}
	if s.http != nil {
		shutdownHTTP("acme http", s.http.Shutdown)
	}

	// NetBird clients get the longer drain timeout.
	if s.netbird != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
			defer cancel()
			if err := s.netbird.StopAll(ctx); err != nil {
				s.Logger.Warnf("stop netbird clients: %v", err)
			}
		}()
	}

	// Close all UDP relays and wait for their goroutines to exit.
	s.udpMu.Lock()
	for id, relay := range s.udpRelays {
		relay.Close()
		delete(s.udpRelays, id)
	}
	s.udpMu.Unlock()
	s.udpRelayWg.Wait()

	// Close all per-port routers: cancel each router's serve context and
	// close its listener, then drop it and the per-service port bookkeeping.
	s.portMu.Lock()
	for port, pr := range s.portRouters {
		pr.cancel()
		if err := pr.listener.Close(); err != nil {
			s.Logger.Debugf("close listener on port %d: %v", port, err)
		}
		delete(s.portRouters, port)
	}
	maps.Clear(s.svcPorts)
	maps.Clear(s.lastMappings)
	s.portMu.Unlock()

	// Wait for per-port router serve goroutines to exit.
	s.portRouterWg.Wait()

	// Wait for the HTTP server and NetBird shutdown goroutines started above.
	wg.Wait()

	// Closed only after every service above has stopped — presumably so
	// late access-log entries can still be delivered; confirm.
	if s.accessLog != nil {
		s.accessLog.Close()
	}

	if s.geoRaw != nil {
		if err := s.geoRaw.Close(); err != nil {
			s.Logger.Debugf("close geolocation: %v", err)
		}
	}
}
|
|
|
|
// resolveDialFunc returns a DialContextFunc that dials through the
|
|
// NetBird tunnel for the given account.
|
|
func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFunc, error) {
|
|
client, ok := s.netbird.GetClient(accountID)
|
|
if !ok {
|
|
return nil, fmt.Errorf("no client for account %s", accountID)
|
|
}
|
|
return client.DialContext, nil
|
|
}
|
|
|
|
// notifyError reports a resource error back to management so it can be
|
|
// surfaced to the user (e.g. port bind failure, dialer resolution error).
|
|
func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
|
|
s.sendStatusUpdate(ctx, types.AccountID(mapping.GetAccountId()), types.ServiceID(mapping.GetId()), proto.ProxyStatus_PROXY_STATUS_ERROR, err)
|
|
}
|
|
|
|
// sendStatusUpdate sends a status update for a service to management.
|
|
func (s *Server) sendStatusUpdate(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, st proto.ProxyStatus, err error) {
|
|
req := &proto.SendStatusUpdateRequest{
|
|
ServiceId: string(serviceID),
|
|
AccountId: string(accountID),
|
|
Status: st,
|
|
}
|
|
if err != nil {
|
|
msg := err.Error()
|
|
req.ErrorMessage = &msg
|
|
}
|
|
if _, sendErr := s.mgmtClient.SendStatusUpdate(ctx, req); sendErr != nil {
|
|
s.Logger.Debugf("failed to send status update for %s: %v", serviceID, sendErr)
|
|
}
|
|
}
|
|
|
|
// routerForPort returns the router that handles the given listen port. If port
|
|
// is 0 or matches the main listener port, the main router is returned.
|
|
// Otherwise a new per-port router is created and started.
|
|
func (s *Server) routerForPort(ctx context.Context, port uint16) (*nbtcp.Router, error) {
|
|
if port == 0 || port == s.mainPort {
|
|
return s.mainRouter, nil
|
|
}
|
|
return s.getOrCreatePortRouter(ctx, port)
|
|
}
|
|
|
|
// routerForPortExisting returns the router for the given port without creating
|
|
// one. Returns the main router for port 0 / mainPort, or nil if no per-port
|
|
// router exists.
|
|
func (s *Server) routerForPortExisting(port uint16) *nbtcp.Router {
|
|
if port == 0 || port == s.mainPort {
|
|
return s.mainRouter
|
|
}
|
|
s.portMu.RLock()
|
|
pr := s.portRouters[port]
|
|
s.portMu.RUnlock()
|
|
if pr != nil {
|
|
return pr.router
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getOrCreatePortRouter returns an existing per-port router or creates one
|
|
// with a new TCP listener and starts serving.
|
|
func (s *Server) getOrCreatePortRouter(ctx context.Context, port uint16) (*nbtcp.Router, error) {
|
|
s.portMu.Lock()
|
|
defer s.portMu.Unlock()
|
|
|
|
if pr, ok := s.portRouters[port]; ok {
|
|
return pr.router, nil
|
|
}
|
|
|
|
listenAddr := fmt.Sprintf(":%d", port)
|
|
ln, err := net.Listen("tcp", listenAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("listen TCP on %s: %w", listenAddr, err)
|
|
}
|
|
if s.ProxyProtocol {
|
|
ln = s.wrapProxyProtocol(ln)
|
|
}
|
|
|
|
router := nbtcp.NewPortRouter(s.Logger, s.resolveDialFunc)
|
|
router.SetObserver(s.meter)
|
|
router.SetAccessLogger(s.accessLog)
|
|
portCtx, cancel := context.WithCancel(ctx)
|
|
|
|
s.portRouters[port] = &portRouter{
|
|
router: router,
|
|
listener: ln,
|
|
cancel: cancel,
|
|
}
|
|
|
|
s.portRouterWg.Add(1)
|
|
go func() {
|
|
defer s.portRouterWg.Done()
|
|
if err := router.Serve(portCtx, ln); err != nil {
|
|
s.Logger.Debugf("port %d router stopped: %v", port, err)
|
|
}
|
|
}()
|
|
|
|
s.Logger.Debugf("started per-port router on %s", listenAddr)
|
|
return router, nil
|
|
}
|
|
|
|
// cleanupPortIfEmpty tears down a per-port router if it has no remaining
|
|
// routes or fallback. The main port is never cleaned up. Active relay
|
|
// connections are drained before the listener is closed.
|
|
func (s *Server) cleanupPortIfEmpty(port uint16) {
|
|
if port == 0 || port == s.mainPort {
|
|
return
|
|
}
|
|
|
|
s.portMu.Lock()
|
|
pr, ok := s.portRouters[port]
|
|
if !ok || !pr.router.IsEmpty() {
|
|
s.portMu.Unlock()
|
|
return
|
|
}
|
|
|
|
// Cancel and close the listener while holding the lock so that
|
|
// getOrCreatePortRouter sees the entry is gone before we drain.
|
|
pr.cancel()
|
|
if err := pr.listener.Close(); err != nil {
|
|
s.Logger.Debugf("close listener on port %d: %v", port, err)
|
|
}
|
|
delete(s.portRouters, port)
|
|
s.portMu.Unlock()
|
|
|
|
// Drain active relay connections outside the lock.
|
|
if ok := pr.router.Drain(nbtcp.DefaultDrainTimeout); !ok {
|
|
s.Logger.Warnf("timed out draining relay connections on port %d", port)
|
|
}
|
|
s.Logger.Debugf("cleaned up empty per-port router on port %d", port)
|
|
}
|
|
|
|
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
|
|
bo := &backoff.ExponentialBackOff{
|
|
InitialInterval: 800 * time.Millisecond,
|
|
RandomizationFactor: 1,
|
|
Multiplier: 1.7,
|
|
MaxInterval: 10 * time.Second,
|
|
MaxElapsedTime: 0, // retry indefinitely until context is canceled
|
|
Stop: backoff.Stop,
|
|
Clock: backoff.SystemClock,
|
|
}
|
|
|
|
initialSyncDone := false
|
|
|
|
operation := func() error {
|
|
s.Logger.Debug("connecting to management mapping stream")
|
|
|
|
if s.healthChecker != nil {
|
|
s.healthChecker.SetManagementConnected(false)
|
|
}
|
|
|
|
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
|
ProxyId: s.ID,
|
|
Version: s.Version,
|
|
StartedAt: timestamppb.New(s.startTime),
|
|
Address: s.ProxyURL,
|
|
Capabilities: &proto.ProxyCapabilities{
|
|
SupportsCustomPorts: &s.SupportsCustomPorts,
|
|
RequireSubdomain: &s.RequireSubdomain,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("create mapping stream: %w", err)
|
|
}
|
|
|
|
if s.healthChecker != nil {
|
|
s.healthChecker.SetManagementConnected(true)
|
|
}
|
|
s.Logger.Debug("management mapping stream established")
|
|
|
|
// Stream established — reset backoff so the next failure retries quickly.
|
|
bo.Reset()
|
|
|
|
streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone)
|
|
|
|
if s.healthChecker != nil {
|
|
s.healthChecker.SetManagementConnected(false)
|
|
}
|
|
|
|
if streamErr == nil {
|
|
return fmt.Errorf("stream closed by server")
|
|
}
|
|
|
|
return fmt.Errorf("mapping stream: %w", streamErr)
|
|
}
|
|
|
|
notify := func(err error, next time.Duration) {
|
|
s.Logger.Warnf("management connection failed, retrying in %s: %v", next.Truncate(time.Millisecond), err)
|
|
}
|
|
|
|
if err := backoff.RetryNotify(operation, backoff.WithContext(bo, ctx), notify); err != nil {
|
|
s.Logger.WithError(err).Debug("management mapping worker exiting")
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error {
|
|
select {
|
|
case <-s.routerReady:
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
|
|
for {
|
|
// Check for context completion to gracefully shutdown.
|
|
select {
|
|
case <-ctx.Done():
|
|
// Shutting down.
|
|
return ctx.Err()
|
|
default:
|
|
msg, err := mappingClient.Recv()
|
|
switch {
|
|
case errors.Is(err, io.EOF):
|
|
// Mapping connection gracefully terminated by server.
|
|
return nil
|
|
case err != nil:
|
|
// Something has gone horribly wrong, return and hope the parent retries the connection.
|
|
return fmt.Errorf("receive msg: %w", err)
|
|
}
|
|
s.Logger.Debug("Received mapping update, starting processing")
|
|
s.processMappings(ctx, msg.GetMapping())
|
|
s.Logger.Debug("Processing mapping update completed")
|
|
|
|
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
|
if s.healthChecker != nil {
|
|
s.healthChecker.SetInitialSyncComplete()
|
|
}
|
|
*initialSyncDone = true
|
|
s.Logger.Info("Initial mapping sync complete")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
|
for _, mapping := range mappings {
|
|
s.Logger.WithFields(log.Fields{
|
|
"type": mapping.GetType(),
|
|
"domain": mapping.GetDomain(),
|
|
"mode": mapping.GetMode(),
|
|
"port": mapping.GetListenPort(),
|
|
"id": mapping.GetId(),
|
|
}).Debug("Processing mapping update")
|
|
switch mapping.GetType() {
|
|
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
|
if err := s.addMapping(ctx, mapping); err != nil {
|
|
s.Logger.WithFields(log.Fields{
|
|
"service_id": mapping.GetId(),
|
|
"domain": mapping.GetDomain(),
|
|
"error": err,
|
|
}).Error("Error adding new mapping, ignoring this mapping and continuing processing")
|
|
s.notifyError(ctx, mapping, err)
|
|
}
|
|
case proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED:
|
|
if err := s.modifyMapping(ctx, mapping); err != nil {
|
|
s.Logger.WithFields(log.Fields{
|
|
"service_id": mapping.GetId(),
|
|
"domain": mapping.GetDomain(),
|
|
"error": err,
|
|
}).Error("failed to modify mapping")
|
|
s.notifyError(ctx, mapping, err)
|
|
}
|
|
case proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED:
|
|
s.removeMapping(ctx, mapping)
|
|
}
|
|
}
|
|
}
|
|
|
|
// addMapping registers a service mapping and starts the appropriate relay or routes.
|
|
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
authToken := mapping.GetAuthToken()
|
|
|
|
svcKey := s.serviceKeyForMapping(mapping)
|
|
if err := s.netbird.AddPeer(ctx, accountID, svcKey, authToken, svcID); err != nil {
|
|
return fmt.Errorf("create peer for service %s: %w", svcID, err)
|
|
}
|
|
|
|
if err := s.setupMappingRoutes(ctx, mapping); err != nil {
|
|
s.cleanupMappingRoutes(mapping)
|
|
if peerErr := s.netbird.RemovePeer(ctx, accountID, svcKey); peerErr != nil {
|
|
s.Logger.WithError(peerErr).WithField("service_id", svcID).Warn("failed to remove peer after setup failure")
|
|
}
|
|
return err
|
|
}
|
|
s.storeMapping(mapping)
|
|
return nil
|
|
}
|
|
|
|
// modifyMapping updates a service mapping in place without tearing down the
|
|
// NetBird peer. It cleans up old routes using the previously stored mapping
|
|
// state and re-applies them from the new mapping.
|
|
func (s *Server) modifyMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
if old := s.loadMapping(types.ServiceID(mapping.GetId())); old != nil {
|
|
s.cleanupMappingRoutes(old)
|
|
if mode := types.ServiceMode(old.GetMode()); mode.IsL4() {
|
|
s.meter.L4ServiceRemoved(mode)
|
|
}
|
|
} else {
|
|
s.cleanupMappingRoutes(mapping)
|
|
}
|
|
if err := s.setupMappingRoutes(ctx, mapping); err != nil {
|
|
s.cleanupMappingRoutes(mapping)
|
|
return err
|
|
}
|
|
s.storeMapping(mapping)
|
|
return nil
|
|
}
|
|
|
|
// setupMappingRoutes configures the appropriate routes or relays for the given
|
|
// service mapping based on its mode. The NetBird peer must already exist.
|
|
func (s *Server) setupMappingRoutes(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
switch types.ServiceMode(mapping.GetMode()) {
|
|
case types.ServiceModeTCP:
|
|
return s.setupTCPMapping(ctx, mapping)
|
|
case types.ServiceModeUDP:
|
|
return s.setupUDPMapping(ctx, mapping)
|
|
case types.ServiceModeTLS:
|
|
return s.setupTLSMapping(ctx, mapping)
|
|
default:
|
|
return s.setupHTTPMapping(ctx, mapping)
|
|
}
|
|
}
|
|
|
|
// setupHTTPMapping configures HTTP reverse proxy, auth, and ACME routes.
|
|
func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
d := domain.Domain(mapping.GetDomain())
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
|
|
if len(mapping.GetPath()) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var wildcardHit bool
|
|
if s.acme != nil {
|
|
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
|
|
}
|
|
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
|
|
Type: nbtcp.RouteHTTP,
|
|
AccountID: accountID,
|
|
ServiceID: svcID,
|
|
Domain: mapping.GetDomain(),
|
|
})
|
|
if err := s.updateMapping(ctx, mapping); err != nil {
|
|
return fmt.Errorf("update mapping for domain %q: %w", d, err)
|
|
}
|
|
|
|
if wildcardHit {
|
|
if err := s.NotifyCertificateIssued(ctx, accountID, svcID, string(d)); err != nil {
|
|
s.Logger.Warnf("notify certificate ready for domain %q: %v", d, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port.
|
|
func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
|
|
port, err := netutil.ValidatePort(mapping.GetListenPort())
|
|
if err != nil {
|
|
return fmt.Errorf("TCP service %s: %w", svcID, err)
|
|
}
|
|
|
|
targetAddr := s.l4TargetAddress(mapping)
|
|
if targetAddr == "" {
|
|
return fmt.Errorf("empty target address for TCP service %s", svcID)
|
|
}
|
|
|
|
if s.WireguardPort != 0 && port == s.WireguardPort {
|
|
return fmt.Errorf("port %d conflicts with tunnel port", port)
|
|
}
|
|
|
|
router, err := s.routerForPort(ctx, port)
|
|
if err != nil {
|
|
return fmt.Errorf("router for TCP port %d: %w", port, err)
|
|
}
|
|
|
|
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
|
|
|
router.SetGeo(s.geo)
|
|
router.SetFallback(nbtcp.Route{
|
|
Type: nbtcp.RouteTCP,
|
|
AccountID: accountID,
|
|
ServiceID: svcID,
|
|
Domain: mapping.GetDomain(),
|
|
Protocol: accesslog.ProtocolTCP,
|
|
Target: targetAddr,
|
|
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
|
DialTimeout: s.l4DialTimeout(mapping),
|
|
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
|
Filter: parseRestrictions(mapping),
|
|
})
|
|
|
|
s.portMu.Lock()
|
|
s.svcPorts[svcID] = []uint16{port}
|
|
s.portMu.Unlock()
|
|
|
|
s.meter.L4ServiceAdded(types.ServiceModeTCP)
|
|
s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil)
|
|
return nil
|
|
}
|
|
|
|
// setupUDPMapping starts a UDP relay on the listen port.
|
|
func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
|
|
port, err := netutil.ValidatePort(mapping.GetListenPort())
|
|
if err != nil {
|
|
return fmt.Errorf("UDP service %s: %w", svcID, err)
|
|
}
|
|
|
|
targetAddr := s.l4TargetAddress(mapping)
|
|
if targetAddr == "" {
|
|
return fmt.Errorf("empty target address for UDP service %s", svcID)
|
|
}
|
|
|
|
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
|
|
|
if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil {
|
|
return fmt.Errorf("UDP relay for service %s: %w", svcID, err)
|
|
}
|
|
|
|
s.meter.L4ServiceAdded(types.ServiceModeUDP)
|
|
s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil)
|
|
return nil
|
|
}
|
|
|
|
// setupTLSMapping configures a TLS SNI-routed passthrough on the listen port.
|
|
func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
|
|
tlsPort, err := netutil.ValidatePort(mapping.GetListenPort())
|
|
if err != nil {
|
|
return fmt.Errorf("TLS service %s: %w", svcID, err)
|
|
}
|
|
|
|
targetAddr := s.l4TargetAddress(mapping)
|
|
if targetAddr == "" {
|
|
return fmt.Errorf("empty target address for TLS service %s", svcID)
|
|
}
|
|
|
|
if s.WireguardPort != 0 && tlsPort == s.WireguardPort {
|
|
return fmt.Errorf("port %d conflicts with tunnel port", tlsPort)
|
|
}
|
|
|
|
router, err := s.routerForPort(ctx, tlsPort)
|
|
if err != nil {
|
|
return fmt.Errorf("router for TLS port %d: %w", tlsPort, err)
|
|
}
|
|
|
|
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
|
|
|
router.SetGeo(s.geo)
|
|
router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
|
|
Type: nbtcp.RouteTCP,
|
|
AccountID: accountID,
|
|
ServiceID: svcID,
|
|
Domain: mapping.GetDomain(),
|
|
Protocol: accesslog.ProtocolTLS,
|
|
Target: targetAddr,
|
|
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
|
DialTimeout: s.l4DialTimeout(mapping),
|
|
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
|
Filter: parseRestrictions(mapping),
|
|
})
|
|
|
|
if tlsPort != s.mainPort {
|
|
s.portMu.Lock()
|
|
s.svcPorts[svcID] = []uint16{tlsPort}
|
|
s.portMu.Unlock()
|
|
}
|
|
|
|
s.Logger.WithFields(log.Fields{
|
|
"domain": mapping.GetDomain(),
|
|
"target": targetAddr,
|
|
"port": tlsPort,
|
|
"service": svcID,
|
|
}).Info("TLS passthrough mapping added")
|
|
|
|
s.meter.L4ServiceAdded(types.ServiceModeTLS)
|
|
s.sendStatusUpdate(ctx, accountID, svcID, proto.ProxyStatus_PROXY_STATUS_ACTIVE, nil)
|
|
return nil
|
|
}
|
|
|
|
// serviceKeyForMapping returns the appropriate ServiceKey for a mapping.
|
|
// TCP/UDP use an ID-based key; HTTP/TLS use a domain-based key.
|
|
func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.ServiceKey {
|
|
switch types.ServiceMode(mapping.GetMode()) {
|
|
case types.ServiceModeTCP, types.ServiceModeUDP:
|
|
return roundtrip.L4ServiceKey(types.ServiceID(mapping.GetId()))
|
|
default:
|
|
return roundtrip.DomainServiceKey(mapping.GetDomain())
|
|
}
|
|
}
|
|
|
|
// parseRestrictions converts a proto mapping's access restrictions into
|
|
// a restrict.Filter. Returns nil if the mapping has no restrictions.
|
|
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
|
|
r := mapping.GetAccessRestrictions()
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
|
|
}
|
|
|
|
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
|
|
// but the proxy has no geolocation database loaded. All requests to this
|
|
// service will be denied at runtime (fail-close).
|
|
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) {
|
|
if r == nil {
|
|
return
|
|
}
|
|
if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 {
|
|
return
|
|
}
|
|
if s.geo != nil && s.geo.Available() {
|
|
return
|
|
}
|
|
s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain)
|
|
}
|
|
|
|
// l4TargetAddress extracts and validates the target address from a mapping's
|
|
// first path entry. Returns empty string if no paths exist or the address is
|
|
// not a valid host:port.
|
|
func (s *Server) l4TargetAddress(mapping *proto.ProxyMapping) string {
|
|
paths := mapping.GetPath()
|
|
if len(paths) == 0 {
|
|
return ""
|
|
}
|
|
target := paths[0].GetTarget()
|
|
if _, _, err := net.SplitHostPort(target); err != nil {
|
|
s.Logger.WithFields(log.Fields{
|
|
"service_id": mapping.GetId(),
|
|
"target": target,
|
|
}).Warnf("invalid L4 target address: %v", err)
|
|
return ""
|
|
}
|
|
return target
|
|
}
|
|
|
|
// l4ProxyProtocol returns whether the first target has PROXY protocol enabled.
|
|
func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool {
|
|
paths := mapping.GetPath()
|
|
if len(paths) == 0 {
|
|
return false
|
|
}
|
|
return paths[0].GetOptions().GetProxyProtocol()
|
|
}
|
|
|
|
// l4DialTimeout returns the dial timeout from the first target's options,
|
|
// clamped to MaxDialTimeout.
|
|
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration {
|
|
paths := mapping.GetPath()
|
|
if len(paths) > 0 {
|
|
if d := paths[0].GetOptions().GetRequestTimeout(); d != nil {
|
|
return s.clampDialTimeout(d.AsDuration())
|
|
}
|
|
}
|
|
return s.clampDialTimeout(0)
|
|
}
|
|
|
|
// l4SessionIdleTimeout returns the configured session idle timeout from the
|
|
// mapping options, or 0 to use the relay's default.
|
|
func l4SessionIdleTimeout(mapping *proto.ProxyMapping) time.Duration {
|
|
paths := mapping.GetPath()
|
|
if len(paths) > 0 {
|
|
if d := paths[0].GetOptions().GetSessionIdleTimeout(); d != nil {
|
|
return d.AsDuration()
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// addUDPRelay starts a UDP relay on the specified listen port.
|
|
func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, targetAddress string, listenPort uint16) error {
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
|
|
if s.WireguardPort != 0 && listenPort == s.WireguardPort {
|
|
return fmt.Errorf("UDP port %d conflicts with tunnel port", listenPort)
|
|
}
|
|
|
|
// Close existing relay if present (idempotent re-add).
|
|
s.removeUDPRelay(svcID)
|
|
|
|
listenAddr := fmt.Sprintf(":%d", listenPort)
|
|
|
|
listener, err := net.ListenPacket("udp", listenAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen UDP on %s: %w", listenAddr, err)
|
|
}
|
|
|
|
dialFn, err := s.resolveDialFunc(accountID)
|
|
if err != nil {
|
|
if err := listener.Close(); err != nil {
|
|
s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err)
|
|
}
|
|
return fmt.Errorf("resolve dialer for UDP: %w", err)
|
|
}
|
|
|
|
entry := s.Logger.WithFields(log.Fields{
|
|
"target": targetAddress,
|
|
"listen_port": listenPort,
|
|
"service_id": svcID,
|
|
})
|
|
|
|
relay := udprelay.New(ctx, udprelay.RelayConfig{
|
|
Logger: entry,
|
|
Listener: listener,
|
|
Target: targetAddress,
|
|
Domain: mapping.GetDomain(),
|
|
AccountID: accountID,
|
|
ServiceID: svcID,
|
|
DialFunc: dialFn,
|
|
DialTimeout: s.l4DialTimeout(mapping),
|
|
SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
|
AccessLog: s.accessLog,
|
|
Filter: parseRestrictions(mapping),
|
|
Geo: s.geo,
|
|
})
|
|
relay.SetObserver(s.meter)
|
|
|
|
s.udpMu.Lock()
|
|
s.udpRelays[svcID] = relay
|
|
s.udpMu.Unlock()
|
|
|
|
s.udpRelayWg.Go(relay.Serve)
|
|
entry.Info("UDP relay added")
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
|
// Very simple implementation here, we don't touch the existing peer
|
|
// connection or any existing TLS configuration, we simply overwrite
|
|
// the auth and proxy mappings.
|
|
// Note: this does require the management server to always send a
|
|
// full mapping rather than deltas during a modification.
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
|
|
var schemes []auth.Scheme
|
|
if mapping.GetAuth().GetPassword() {
|
|
schemes = append(schemes, auth.NewPassword(s.mgmtClient, svcID, accountID))
|
|
}
|
|
if mapping.GetAuth().GetPin() {
|
|
schemes = append(schemes, auth.NewPin(s.mgmtClient, svcID, accountID))
|
|
}
|
|
if mapping.GetAuth().GetOidc() {
|
|
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto))
|
|
}
|
|
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
|
|
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
|
|
}
|
|
|
|
ipRestrictions := parseRestrictions(mapping)
|
|
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
|
|
|
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
|
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
|
|
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
|
|
}
|
|
m := s.protoToMapping(ctx, mapping)
|
|
s.proxy.AddMapping(m)
|
|
s.meter.AddMapping(m)
|
|
return nil
|
|
}
|
|
|
|
// removeMapping tears down routes/relays and the NetBird peer for a service.
|
|
// Uses the stored mapping state when available to ensure all previously
|
|
// configured routes are cleaned up.
|
|
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
|
accountID := types.AccountID(mapping.GetAccountId())
|
|
svcKey := s.serviceKeyForMapping(mapping)
|
|
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
|
|
s.Logger.WithFields(log.Fields{
|
|
"account_id": accountID,
|
|
"service_id": mapping.GetId(),
|
|
"error": err,
|
|
}).Error("failed to remove NetBird peer, continuing cleanup")
|
|
}
|
|
|
|
if old := s.deleteMapping(types.ServiceID(mapping.GetId())); old != nil {
|
|
s.cleanupMappingRoutes(old)
|
|
if mode := types.ServiceMode(old.GetMode()); mode.IsL4() {
|
|
s.meter.L4ServiceRemoved(mode)
|
|
}
|
|
} else {
|
|
s.cleanupMappingRoutes(mapping)
|
|
}
|
|
}
|
|
|
|
// cleanupMappingRoutes removes HTTP/TLS/L4 routes and custom port state for a
|
|
// service without touching the NetBird peer. This is used for both full
|
|
// removal and in-place modification of mappings.
|
|
func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) {
|
|
svcID := types.ServiceID(mapping.GetId())
|
|
host := mapping.GetDomain()
|
|
|
|
// HTTP/TLS cleanup (only relevant when a domain is set).
|
|
if host != "" {
|
|
d := domain.Domain(host)
|
|
if s.acme != nil {
|
|
s.acme.RemoveDomain(d)
|
|
}
|
|
s.auth.RemoveDomain(host)
|
|
if s.proxy.RemoveMapping(proxy.Mapping{Host: host}) {
|
|
s.meter.RemoveMapping(proxy.Mapping{Host: host})
|
|
}
|
|
// Close hijacked connections (WebSocket) for this domain.
|
|
if n := s.hijackTracker.CloseByHost(host); n > 0 {
|
|
s.Logger.Debugf("closed %d hijacked connection(s) for %s", n, host)
|
|
}
|
|
// Remove SNI route from the main router (covers both HTTP and main-port TLS).
|
|
s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID)
|
|
}
|
|
|
|
// Extract and delete tracked custom-port entries atomically.
|
|
s.portMu.Lock()
|
|
entries := s.svcPorts[svcID]
|
|
delete(s.svcPorts, svcID)
|
|
s.portMu.Unlock()
|
|
|
|
for _, entry := range entries {
|
|
if router := s.routerForPortExisting(entry); router != nil {
|
|
if host != "" {
|
|
router.RemoveRoute(nbtcp.SNIHost(host), svcID)
|
|
} else {
|
|
router.RemoveFallback(svcID)
|
|
}
|
|
}
|
|
s.cleanupPortIfEmpty(entry)
|
|
}
|
|
|
|
// UDP relay cleanup (idempotent).
|
|
s.removeUDPRelay(svcID)
|
|
|
|
}
|
|
|
|
// removeUDPRelay stops and removes a UDP relay by service ID.
|
|
func (s *Server) removeUDPRelay(svcID types.ServiceID) {
|
|
s.udpMu.Lock()
|
|
relay, ok := s.udpRelays[svcID]
|
|
if ok {
|
|
delete(s.udpRelays, svcID)
|
|
}
|
|
s.udpMu.Unlock()
|
|
|
|
if ok {
|
|
relay.Close()
|
|
s.Logger.WithField("service_id", svcID).Info("UDP relay removed")
|
|
}
|
|
}
|
|
|
|
func (s *Server) storeMapping(mapping *proto.ProxyMapping) {
|
|
s.portMu.Lock()
|
|
s.lastMappings[types.ServiceID(mapping.GetId())] = mapping
|
|
s.portMu.Unlock()
|
|
}
|
|
|
|
func (s *Server) loadMapping(svcID types.ServiceID) *proto.ProxyMapping {
|
|
s.portMu.RLock()
|
|
m := s.lastMappings[svcID]
|
|
s.portMu.RUnlock()
|
|
return m
|
|
}
|
|
|
|
func (s *Server) deleteMapping(svcID types.ServiceID) *proto.ProxyMapping {
|
|
s.portMu.Lock()
|
|
m := s.lastMappings[svcID]
|
|
delete(s.lastMappings, svcID)
|
|
s.portMu.Unlock()
|
|
return m
|
|
}
|
|
|
|
func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping) proxy.Mapping {
|
|
paths := make(map[string]*proxy.PathTarget)
|
|
for _, pathMapping := range mapping.GetPath() {
|
|
targetURL, err := url.Parse(pathMapping.GetTarget())
|
|
if err != nil {
|
|
s.Logger.WithFields(log.Fields{
|
|
"service_id": mapping.GetId(),
|
|
"account_id": mapping.GetAccountId(),
|
|
"domain": mapping.GetDomain(),
|
|
"path": pathMapping.GetPath(),
|
|
"target": pathMapping.GetTarget(),
|
|
}).WithError(err).Error("failed to parse target URL for path, skipping")
|
|
s.notifyError(ctx, mapping, fmt.Errorf("invalid target URL %q for path %q: %w", pathMapping.GetTarget(), pathMapping.GetPath(), err))
|
|
continue
|
|
}
|
|
|
|
pt := &proxy.PathTarget{URL: targetURL}
|
|
if opts := pathMapping.GetOptions(); opts != nil {
|
|
pt.SkipTLSVerify = opts.GetSkipTlsVerify()
|
|
pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite())
|
|
pt.CustomHeaders = opts.GetCustomHeaders()
|
|
if d := opts.GetRequestTimeout(); d != nil {
|
|
pt.RequestTimeout = d.AsDuration()
|
|
}
|
|
}
|
|
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
|
|
paths[pathMapping.GetPath()] = pt
|
|
}
|
|
m := proxy.Mapping{
|
|
ID: types.ServiceID(mapping.GetId()),
|
|
AccountID: types.AccountID(mapping.GetAccountId()),
|
|
Host: mapping.GetDomain(),
|
|
Paths: paths,
|
|
PassHostHeader: mapping.GetPassHostHeader(),
|
|
RewriteRedirects: mapping.GetRewriteRedirects(),
|
|
}
|
|
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
|
|
m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader())
|
|
}
|
|
return m
|
|
}
|
|
|
|
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
|
|
switch mode {
|
|
case proto.PathRewriteMode_PATH_REWRITE_PRESERVE:
|
|
return proxy.PathRewritePreserve
|
|
default:
|
|
return proxy.PathRewriteDefault
|
|
}
|
|
}
|
|
|
|
// debugEndpointAddr returns the address for the debug endpoint.
|
|
// If addr is empty, it defaults to localhost:8444 for security.
|
|
func debugEndpointAddr(addr string) string {
|
|
if addr == "" {
|
|
return defaultDebugAddr
|
|
}
|
|
return addr
|
|
}
|