mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Add listener side proxy protocol support and enable it in traefik (#5332)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
This commit is contained in:
@@ -56,6 +56,7 @@ var (
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wgPort int
|
||||
proxyProtocol bool
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -90,6 +91,7 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -165,6 +167,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
WireguardPort: wgPort,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
106
proxy/proxyprotocol_test.go
Normal file
106
proxy/proxyprotocol_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
|
||||
ProxyProtocol: true,
|
||||
}
|
||||
|
||||
raw, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer raw.Close()
|
||||
|
||||
ln := srv.wrapProxyProtocol(raw)
|
||||
|
||||
realClientIP := "203.0.113.50"
|
||||
realClientPort := uint16(54321)
|
||||
|
||||
accepted := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
accepted <- conn
|
||||
}()
|
||||
|
||||
// Connect and send a PROXY v2 header.
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
header := &proxyproto.Header{
|
||||
Version: 2,
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: proxyproto.TCPv4,
|
||||
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
|
||||
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
|
||||
}
|
||||
_, err = header.WriteTo(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case accepted := <-accepted:
|
||||
defer accepted.Close()
|
||||
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
|
||||
}
|
||||
180
proxy/server.go
180
proxy/server.go
@@ -23,6 +23,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -92,7 +93,7 @@ type Server struct {
|
||||
DebugEndpointEnabled bool
|
||||
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
|
||||
DebugEndpointAddress string
|
||||
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
|
||||
// HealthAddress is the address for the health probe endpoint.
|
||||
HealthAddress string
|
||||
// ProxyToken is the access token for authenticating with the management server.
|
||||
ProxyToken string
|
||||
@@ -107,6 +108,10 @@ type Server struct {
|
||||
// random OS-assigned port. A fixed port only works with single-account
|
||||
// deployments; multiple accounts will fail to bind the same port.
|
||||
WireguardPort int
|
||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||
// When enabled, the real client IP is extracted from the PROXY header
|
||||
// sent by upstream L4 proxies that support PROXY protocol.
|
||||
ProxyProtocol bool
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||
@@ -137,23 +142,8 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.startTime = time.Now()
|
||||
s.initDefaults()
|
||||
|
||||
// If no ID is set then one can be generated.
|
||||
if s.ID == "" {
|
||||
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
||||
}
|
||||
// Fallback version option in case it is not set.
|
||||
if s.Version == "" {
|
||||
s.Version = "dev"
|
||||
}
|
||||
|
||||
// If no logger is specified fallback to the standard logger.
|
||||
if s.Logger == nil {
|
||||
s.Logger = log.StandardLogger()
|
||||
}
|
||||
|
||||
// Start up metrics gathering
|
||||
reg := prometheus.NewRegistry()
|
||||
s.meter = metrics.New(reg)
|
||||
|
||||
@@ -189,40 +179,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
if s.DebugEndpointEnabled {
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
s.startDebugEndpoint()
|
||||
|
||||
// Start health probe server.
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = "localhost:8080"
|
||||
if err := s.startHealthServer(reg); err != nil {
|
||||
return err
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
@@ -232,10 +193,19 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ListenAndServeTLS("", "")
|
||||
httpsErr <- s.https.ServeTLS(ln, "", "")
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -251,7 +221,115 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// initDefaults sets fallback values for optional Server fields.
|
||||
func (s *Server) initDefaults() {
|
||||
s.startTime = time.Now()
|
||||
|
||||
// If no ID is set then one can be generated.
|
||||
if s.ID == "" {
|
||||
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
||||
}
|
||||
// Fallback version option in case it is not set.
|
||||
if s.Version == "" {
|
||||
s.Version = "dev"
|
||||
}
|
||||
|
||||
// If no logger is specified fallback to the standard logger.
|
||||
if s.Logger == nil {
|
||||
s.Logger = log.StandardLogger()
|
||||
}
|
||||
}
|
||||
|
||||
// startDebugEndpoint launches the debug HTTP server if enabled.
|
||||
func (s *Server) startDebugEndpoint() {
|
||||
if !s.DebugEndpointEnabled {
|
||||
return
|
||||
}
|
||||
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
|
||||
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
|
||||
}
|
||||
go func() {
|
||||
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
|
||||
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// startHealthServer launches the health probe and metrics server.
|
||||
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = defaultHealthAddr
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapProxyProtocol wraps a listener with PROXY protocol support.
|
||||
// When TrustedProxies is configured, only those sources may send PROXY headers;
|
||||
// connections from untrusted sources have any PROXY header ignored.
|
||||
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
|
||||
ppListener := &proxyproto.Listener{
|
||||
Listener: ln,
|
||||
ReadHeaderTimeout: proxyProtoHeaderTimeout,
|
||||
}
|
||||
if len(s.TrustedProxies) > 0 {
|
||||
ppListener.ConnPolicy = s.proxyProtocolPolicy
|
||||
} else {
|
||||
s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
|
||||
}
|
||||
s.Logger.Info("PROXY protocol enabled on listener")
|
||||
return ppListener
|
||||
}
|
||||
|
||||
// proxyProtocolPolicy returns whether to require, skip, or reject the PROXY
|
||||
// header based on whether the connection source is in TrustedProxies.
|
||||
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
|
||||
// No logging on reject to prevent abuse
|
||||
tcpAddr, ok := opts.Upstream.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
return proxyproto.REJECT, nil
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
// called per accept
|
||||
for _, prefix := range s.TrustedProxies {
|
||||
if prefix.Contains(addr) {
|
||||
return proxyproto.REQUIRE, nil
|
||||
}
|
||||
}
|
||||
return proxyproto.IGNORE, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHealthAddr = "localhost:8080"
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
|
||||
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
|
||||
// header after accepting a connection.
|
||||
proxyProtoHeaderTimeout = 5 * time.Second
|
||||
|
||||
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
||||
// before draining connections. This allows the load balancer to propagate
|
||||
// the endpoint removal.
|
||||
@@ -647,7 +725,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
// If addr is empty, it defaults to localhost:8444 for security.
|
||||
func debugEndpointAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return "localhost:8444"
|
||||
return defaultDebugAddr
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user