[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -7,6 +7,7 @@ import (
"os/signal"
"strconv"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -34,30 +35,32 @@ var (
)
var (
logLevel string
debugLogs bool
mgmtAddr string
addr string
proxyDomain string
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeEABKID string
acmeEABHMACKey string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wildcardCertDir string
wgPort int
proxyProtocol bool
preSharedKey string
logLevel string
debugLogs bool
mgmtAddr string
addr string
proxyDomain string
defaultDialTimeout time.Duration
certDir string
acmeCerts bool
acmeAddr string
acmeDir string
acmeEABKID string
acmeEABHMACKey string
acmeChallengeType string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wildcardCertDir string
wgPort uint16
proxyProtocol bool
preSharedKey string
supportsCustomPorts bool
)
var rootCmd = &cobra.Command{
@@ -92,9 +95,11 @@ func init() {
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
rootCmd.Flags().StringVar(&wildcardCertDir, "wildcard-cert-dir", envStringOrDefault("NB_PROXY_WILDCARD_CERT_DIR", ""), "Directory containing wildcard certificate pairs (<name>.crt/<name>.key). Wildcard patterns are extracted from SANs automatically")
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
rootCmd.Flags().Uint16Var(&wgPort, "wg-port", envUint16OrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)")
}
// Execute runs the root command.
@@ -171,6 +176,8 @@ func runServer(cmd *cobra.Command, args []string) error {
WireguardPort: wgPort,
ProxyProtocol: proxyProtocol,
PreSharedKey: preSharedKey,
SupportsCustomPorts: supportsCustomPorts,
DefaultDialTimeout: defaultDialTimeout,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
@@ -203,12 +210,24 @@ func envStringOrDefault(key string, def string) string {
return v
}
func envIntOrDefault(key string, def int) int {
// envUint16OrDefault reads the environment variable key and parses it as a
// base-10 uint16. It falls back to def when the variable is unset or when
// the value is not a valid number in the uint16 range.
func envUint16OrDefault(key string, def uint16) uint16 {
	raw, ok := os.LookupEnv(key)
	if !ok {
		return def
	}
	if n, err := strconv.ParseUint(raw, 10, 16); err == nil {
		return uint16(n)
	}
	return def
}
// envDurationOrDefault reads the environment variable key and parses it with
// time.ParseDuration (e.g. "30s", "1m"). It returns def when the variable is
// unset or its value is not a valid duration string.
//
// Fix: the visible body was truncated — it parsed the value but never
// returned it and the function was not closed; the success path
// (return parsed) and closing brace are restored here.
func envDurationOrDefault(key string, def time.Duration) time.Duration {
	v, exists := os.LookupEnv(key)
	if !exists {
		return def
	}
	parsed, err := time.ParseDuration(v)
	if err != nil {
		return def
	}
	return parsed
}

View File

@@ -38,11 +38,18 @@ func (m *mockMappingStream) Context() context.Context { return context.Backgroun
func (m *mockMappingStream) SendMsg(any) error { return nil }
func (m *mockMappingStream) RecvMsg(any) error { return nil }
func closedChan() chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}
func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
routerReady: closedChan(),
}
stream := &mockMappingStream{
@@ -62,6 +69,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
routerReady: closedChan(),
}
stream := &mockMappingStream{
@@ -78,7 +86,8 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
Logger: log.StandardLogger(),
routerReady: closedChan(),
}
stream := &mockMappingStream{

View File

@@ -6,11 +6,13 @@ import (
"sync"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -19,6 +21,7 @@ const (
bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
logSendTimeout = 10 * time.Second
)
type domainUsage struct {
@@ -79,22 +82,63 @@ func (l *Logger) Close() {
type logEntry struct {
ID string
AccountID string
ServiceId string
AccountID types.AccountID
ServiceId types.ServiceID
Host string
Path string
DurationMs int64
Method string
ResponseCode int32
SourceIp string
SourceIP netip.Addr
AuthMechanism string
UserId string
AuthSuccess bool
BytesUpload int64
BytesDownload int64
Protocol Protocol
}
func (l *Logger) log(ctx context.Context, entry logEntry) {
// Protocol identifies the transport protocol of an access log entry.
type Protocol string
// Known Protocol values attached to access log entries.
const (
// ProtocolHTTP marks entries produced by the HTTP reverse-proxy path.
ProtocolHTTP Protocol = "http"
// ProtocolTCP marks entries produced by TCP relays.
ProtocolTCP Protocol = "tcp"
// ProtocolUDP marks entries produced by UDP relays.
ProtocolUDP Protocol = "udp"
// ProtocolTLS marks entries for TLS connections.
ProtocolTLS Protocol = "tls"
)
// L4Entry holds the data for a layer-4 (TCP/UDP) access log entry.
type L4Entry struct {
AccountID types.AccountID // owning account of the connection
ServiceID types.ServiceID // service the connection was routed to
Protocol Protocol // transport protocol (ProtocolTCP, ProtocolUDP, ...)
Host string // SNI hostname or listen address
SourceIP netip.Addr // client address; may be the zero (invalid) Addr when unknown
DurationMs int64 // connection lifetime in milliseconds
BytesUpload int64 // bytes counted in the upload direction
BytesDownload int64 // bytes counted in the download direction
}
// LogL4 records an access log entry for a layer-4 connection (TCP or UDP).
// It is non-blocking: the underlying gRPC send runs on a background goroutine.
func (l *Logger) LogL4(entry L4Entry) {
	rec := logEntry{
		ID:            xid.New().String(),
		AccountID:     entry.AccountID,
		ServiceId:     entry.ServiceID,
		Protocol:      entry.Protocol,
		Host:          entry.Host,
		SourceIP:      entry.SourceIP,
		DurationMs:    entry.DurationMs,
		BytesUpload:   entry.BytesUpload,
		BytesDownload: entry.BytesDownload,
	}
	l.log(rec)
	// Usage accounting for cost monitoring is keyed on the host/domain and
	// counts traffic in both directions.
	l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload)
}
func (l *Logger) log(entry logEntry) {
// Fire off the log request in a separate routine.
// This increases the possibility of losing a log message
// (although it should still get logged in the event of an error),
@@ -105,31 +149,37 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
// allow for resolving that on the server.
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
go func() {
logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout)
defer cancel()
if entry.AuthMechanism != auth.MethodOIDC.String() {
entry.UserId = ""
}
var sourceIP string
if entry.SourceIP.IsValid() {
sourceIP = entry.SourceIP.String()
}
if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{
Log: &proto.AccessLog{
LogId: entry.ID,
AccountId: entry.AccountID,
AccountId: string(entry.AccountID),
Timestamp: now,
ServiceId: entry.ServiceId,
ServiceId: string(entry.ServiceId),
Host: entry.Host,
Path: entry.Path,
DurationMs: entry.DurationMs,
Method: entry.Method,
ResponseCode: entry.ResponseCode,
SourceIp: entry.SourceIp,
SourceIp: sourceIP,
AuthMechanism: entry.AuthMechanism,
UserId: entry.UserId,
AuthSuccess: entry.AuthSuccess,
BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload,
Protocol: string(entry.Protocol),
},
}); err != nil {
// If it fails to send on the gRPC connection, then at least log it to the error log.
l.logger.WithFields(log.Fields{
"service_id": entry.ServiceId,
"host": entry.Host,
@@ -137,7 +187,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
"duration": entry.DurationMs,
"method": entry.Method,
"response_code": entry.ResponseCode,
"source_ip": entry.SourceIp,
"source_ip": sourceIP,
"auth_mechanism": entry.AuthMechanism,
"user_id": entry.UserId,
"auth_success": entry.AuthSuccess,

View File

@@ -67,23 +67,24 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
entry := logEntry{
ID: requestID,
ServiceId: capturedData.GetServiceId(),
AccountID: string(capturedData.GetAccountId()),
AccountID: capturedData.GetAccountId(),
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
Method: r.Method,
ResponseCode: int32(sw.status),
SourceIp: sourceIp,
SourceIP: sourceIp,
AuthMechanism: capturedData.GetAuthMethod(),
UserId: capturedData.GetUserID(),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
BytesUpload: bytesUpload,
BytesDownload: bytesDownload,
Protocol: ProtocolHTTP,
}
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
l.log(r.Context(), entry)
l.log(entry)
// Track usage for cost monitoring (upload + download) by domain
l.trackUsage(host, bytesUpload+bytesDownload)

View File

@@ -11,6 +11,6 @@ import (
// proxy configuration. When trustedProxies is non-empty and the direct
// connection is from a trusted source, it walks X-Forwarded-For right-to-left
// skipping trusted IPs. Otherwise it returns RemoteAddr directly.
func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string {
func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) netip.Addr {
return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies)
}

View File

@@ -23,6 +23,7 @@ import (
"golang.org/x/crypto/acme/autocert"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -30,7 +31,7 @@ import (
var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2}
type certificateNotifier interface {
NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error
NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error
}
type domainState int
@@ -42,8 +43,8 @@ const (
)
type domainInfo struct {
accountID string
serviceID string
accountID types.AccountID
serviceID types.ServiceID
state domainState
err string
}
@@ -301,7 +302,7 @@ func (mgr *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate
// When AddDomain returns true the caller is responsible for sending any
// certificate-ready notifications after the surrounding operation (e.g.
// mapping update) has committed successfully.
func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) (wildcardHit bool) {
func (mgr *Manager) AddDomain(d domain.Domain, accountID types.AccountID, serviceID types.ServiceID) (wildcardHit bool) {
name := d.PunycodeString()
if e := mgr.findWildcardEntry(name); e != nil {
mgr.mu.Lock()

View File

@@ -17,12 +17,14 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/types"
)
func TestHostPolicy(t *testing.T) {
mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil)
require.NoError(t, err)
mgr.AddDomain("example.com", "acc1", "rp1")
mgr.AddDomain("example.com", types.AccountID("acc1"), types.ServiceID("rp1"))
// Wait for the background prefetch goroutine to finish so the temp dir
// can be cleaned up without a race.
@@ -92,8 +94,8 @@ func TestDomainStates(t *testing.T) {
// AddDomain starts as pending, then the prefetch goroutine will fail
// (no real ACME server) and transition to failed.
mgr.AddDomain("a.example.com", "acc1", "rp1")
mgr.AddDomain("b.example.com", "acc1", "rp1")
mgr.AddDomain("a.example.com", types.AccountID("acc1"), types.ServiceID("rp1"))
mgr.AddDomain("b.example.com", types.AccountID("acc1"), types.ServiceID("rp1"))
assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered")
@@ -209,12 +211,12 @@ func TestWildcardAddDomainSkipsACME(t *testing.T) {
require.NoError(t, err)
// Add a wildcard-matching domain — should be immediately ready.
mgr.AddDomain("foo.example.com", "acc1", "svc1")
mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1"))
assert.Equal(t, 0, mgr.PendingCerts(), "wildcard domain should not be pending")
assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains())
// Add a non-wildcard domain — should go through ACME (pending then failed).
mgr.AddDomain("other.net", "acc2", "svc2")
mgr.AddDomain("other.net", types.AccountID("acc2"), types.ServiceID("svc2"))
assert.Equal(t, 2, mgr.TotalDomains())
// Wait for the ACME prefetch to fail.
@@ -234,7 +236,7 @@ func TestWildcardGetCertificate(t *testing.T) {
mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil)
require.NoError(t, err)
mgr.AddDomain("foo.example.com", "acc1", "svc1")
mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1"))
// GetCertificate for a wildcard-matching domain should return the static cert.
cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
@@ -255,8 +257,8 @@ func TestMultipleWildcards(t *testing.T) {
assert.ElementsMatch(t, []string{"*.example.com", "*.other.org"}, mgr.WildcardPatterns())
// Both wildcards should resolve.
mgr.AddDomain("foo.example.com", "acc1", "svc1")
mgr.AddDomain("bar.other.org", "acc2", "svc2")
mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1"))
mgr.AddDomain("bar.other.org", types.AccountID("acc2"), types.ServiceID("svc2"))
assert.Equal(t, 0, mgr.PendingCerts())
assert.ElementsMatch(t, []string{"foo.example.com", "bar.other.org"}, mgr.ReadyDomains())
@@ -271,7 +273,7 @@ func TestMultipleWildcards(t *testing.T) {
assert.Contains(t, cert2.Leaf.DNSNames, "*.other.org")
// Non-matching domain falls through to ACME.
mgr.AddDomain("custom.net", "acc3", "svc3")
mgr.AddDomain("custom.net", types.AccountID("acc3"), types.ServiceID("svc3"))
assert.Eventually(t, func() bool {
return mgr.PendingCerts() == 0
}, 30*time.Second, 100*time.Millisecond)

View File

@@ -44,8 +44,8 @@ type DomainConfig struct {
Schemes []Scheme
SessionPublicKey ed25519.PublicKey
SessionExpiration time.Duration
AccountID string
ServiceID string
AccountID types.AccountID
ServiceID types.ServiceID
}
type validationResult struct {
@@ -124,7 +124,7 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
func setCapturedIDs(r *http.Request, config DomainConfig) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(types.AccountID(config.AccountID))
cd.SetAccountId(config.AccountID)
cd.SetServiceId(config.ServiceID)
}
}
@@ -275,7 +275,7 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error {
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()

View File

@@ -9,6 +9,7 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -17,14 +18,14 @@ type urlGenerator interface {
}
type OIDC struct {
id string
accountId string
id types.ServiceID
accountId types.AccountID
forwardedProto string
client urlGenerator
}
// NewOIDC creates a new OIDC authentication scheme
func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC {
func NewOIDC(client urlGenerator, id types.ServiceID, accountId types.AccountID, forwardedProto string) OIDC {
return OIDC{
id: id,
accountId: accountId,
@@ -53,8 +54,8 @@ func (o OIDC) Authenticate(r *http.Request) (string, string, error) {
}
res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{
Id: o.id,
AccountId: o.accountId,
Id: string(o.id),
AccountId: string(o.accountId),
RedirectUrl: redirectURL.String(),
})
if err != nil {

View File

@@ -5,17 +5,19 @@ import (
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
const passwordFormId = "password"
type Password struct {
id, accountId string
client authenticator
id types.ServiceID
accountId types.AccountID
client authenticator
}
func NewPassword(client authenticator, id, accountId string) Password {
func NewPassword(client authenticator, id types.ServiceID, accountId types.AccountID) Password {
return Password{
id: id,
accountId: accountId,
@@ -41,8 +43,8 @@ func (p Password) Authenticate(r *http.Request) (string, string, error) {
}
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
Id: p.id,
AccountId: p.accountId,
Id: string(p.id),
AccountId: string(p.accountId),
Request: &proto.AuthenticateRequest_Password{
Password: &proto.PasswordRequest{
Password: password,

View File

@@ -5,17 +5,19 @@ import (
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
const pinFormId = "pin"
type Pin struct {
id, accountId string
client authenticator
id types.ServiceID
accountId types.AccountID
client authenticator
}
func NewPin(client authenticator, id, accountId string) Pin {
func NewPin(client authenticator, id types.ServiceID, accountId types.AccountID) Pin {
return Pin{
id: id,
accountId: accountId,
@@ -41,8 +43,8 @@ func (p Pin) Authenticate(r *http.Request) (string, string, error) {
}
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
Id: p.id,
AccountId: p.accountId,
Id: string(p.id),
AccountId: string(p.accountId),
Request: &proto.AuthenticateRequest_Pin{
Pin: &proto.PinRequest{
Pin: pin,

View File

@@ -10,10 +10,11 @@ import (
type trackedConn struct {
net.Conn
tracker *HijackTracker
host string
}
func (c *trackedConn) Close() error {
c.tracker.conns.Delete(c)
c.tracker.remove(c)
return c.Conn.Close()
}
@@ -22,6 +23,7 @@ func (c *trackedConn) Close() error {
type trackingWriter struct {
http.ResponseWriter
tracker *HijackTracker
host string
}
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
@@ -33,8 +35,8 @@ func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if err != nil {
return nil, nil, err
}
tc := &trackedConn{Conn: conn, tracker: w.tracker}
w.tracker.conns.Store(tc, struct{}{})
tc := &trackedConn{Conn: conn, tracker: w.tracker, host: w.host}
w.tracker.add(tc)
return tc, buf, nil
}

View File

@@ -1,7 +1,6 @@
package conntrack
import (
"net"
"net/http"
"sync"
)
@@ -10,10 +9,14 @@ import (
// upgrades). http.Server.Shutdown does not close hijacked connections, so
// they must be tracked and closed explicitly during graceful shutdown.
//
// Connections are indexed by the request Host so they can be closed
// per-domain when a service mapping is removed.
//
// Use Middleware as the outermost HTTP middleware to ensure hijacked
// connections are tracked and automatically deregistered when closed.
type HijackTracker struct {
conns sync.Map // net.Conn → struct{}
mu sync.Mutex
conns map[*trackedConn]struct{}
}
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
@@ -21,21 +24,73 @@ type HijackTracker struct {
// tracker when closed. This should be the outermost middleware in the chain.
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
next.ServeHTTP(&trackingWriter{
ResponseWriter: w,
tracker: t,
host: hostOnly(r.Host),
}, r)
})
}
// CloseAll closes all tracked hijacked connections and returns the number
// of connections that were closed.
// CloseAll closes all tracked hijacked connections and returns the count.
func (t *HijackTracker) CloseAll() int {
var count int
t.conns.Range(func(key, _ any) bool {
if conn, ok := key.(net.Conn); ok {
_ = conn.Close()
count++
}
t.conns.Delete(key)
return true
})
return count
t.mu.Lock()
conns := t.conns
t.conns = nil
t.mu.Unlock()
for tc := range conns {
_ = tc.Conn.Close()
}
return len(conns)
}
// CloseByHost closes every tracked hijacked connection registered under the
// given host (port stripped before matching) and returns how many were closed.
func (t *HijackTracker) CloseByHost(host string) int {
	want := hostOnly(host)

	t.mu.Lock()
	var matched []*trackedConn
	for tc := range t.conns {
		if tc.host == want {
			// Deleting during range iteration is safe for Go maps.
			matched = append(matched, tc)
			delete(t.conns, tc)
		}
	}
	t.mu.Unlock()

	// Close outside the lock: Close may block, and we go through the
	// embedded Conn directly since the entries are already deregistered
	// (trackedConn.Close would call back into remove and retake the mutex).
	for _, tc := range matched {
		_ = tc.Conn.Close()
	}
	return len(matched)
}
// add registers a hijacked connection, lazily allocating the map so the
// zero-value HijackTracker is usable (and after CloseAll set it to nil).
func (t *HijackTracker) add(tc *trackedConn) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.conns == nil {
		t.conns = make(map[*trackedConn]struct{})
	}
	t.conns[tc] = struct{}{}
}
// remove deregisters a tracked connection; called from trackedConn.Close.
// delete on a nil map is a no-op, so this stays safe after CloseAll has
// set t.conns to nil.
func (t *HijackTracker) remove(tc *trackedConn) {
t.mu.Lock()
delete(t.conns, tc)
t.mu.Unlock()
}
// hostOnly strips a trailing ":<digits>" port suffix from a host:port string.
// Inputs that do not end in a colon followed only by digits are returned
// unchanged, which leaves bracketed IPv6 literals like "[::1]" intact while
// still turning "[::1]:443" into "[::1]".
func hostOnly(hostport string) string {
	i := len(hostport)
	for i > 0 {
		i--
		c := hostport[i]
		switch {
		case c == ':':
			return hostport[:i]
		case c < '0' || c > '9':
			return hostport
		}
	}
	return hostport
}

View File

@@ -0,0 +1,142 @@
package conntrack
import (
"bufio"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing.
type fakeHijackWriter struct {
http.ResponseWriter
conn net.Conn // the connection handed back by Hijack
}
// Hijack satisfies http.Hijacker by returning the pre-configured conn
// wrapped in a fresh buffered ReadWriter; it never fails.
func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
return f.conn, rw, nil
}
// TestCloseByHost verifies that CloseByHost closes only the connections
// registered under the matching host, leaves other hosts' connections
// tracked and alive, and that CloseAll then drains the remainder.
func TestCloseByHost(t *testing.T) {
var tracker HijackTracker
// Simulate hijacking two connections for different hosts.
connA1, connA2 := net.Pipe()
defer connA2.Close()
connB1, connB2 := net.Pipe()
defer connB2.Close()
twA := &trackingWriter{
ResponseWriter: httptest.NewRecorder(),
tracker: &tracker,
host: "a.example.com",
}
twB := &trackingWriter{
ResponseWriter: httptest.NewRecorder(),
tracker: &tracker,
host: "b.example.com",
}
// Use fakeHijackWriter to provide the Hijack method.
twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1}
twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1}
_, _, err := twA.Hijack()
require.NoError(t, err)
_, _, err = twB.Hijack()
require.NoError(t, err)
tracker.mu.Lock()
assert.Equal(t, 2, len(tracker.conns), "should track 2 connections")
tracker.mu.Unlock()
// Close only host A.
n := tracker.CloseByHost("a.example.com")
assert.Equal(t, 1, n, "should close 1 connection for host A")
tracker.mu.Lock()
assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection")
tracker.mu.Unlock()
// Verify host A's conn is actually closed.
buf := make([]byte, 1)
_, err = connA2.Read(buf)
assert.Error(t, err, "host A pipe should be closed")
// Host B should still be alive.
// NOTE(review): net.Pipe is synchronous, so this Write blocks until
// CloseAll below closes connB1, which unblocks it with an error.
go func() { _, _ = connB1.Write([]byte("x")) }()
// Close all remaining.
n = tracker.CloseAll()
assert.Equal(t, 1, n, "should close remaining 1 connection")
tracker.mu.Lock()
assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll")
tracker.mu.Unlock()
}
// TestCloseAll checks that CloseAll closes every tracked connection and
// reports the count, and that a second CloseAll is a harmless no-op.
func TestCloseAll(t *testing.T) {
	var tracker HijackTracker
	const total = 5
	for i := 0; i < total; i++ {
		client, server := net.Pipe()
		defer server.Close()
		tracker.add(&trackedConn{Conn: client, tracker: &tracker, host: "test.com"})
	}
	tracker.mu.Lock()
	assert.Equal(t, total, len(tracker.conns))
	tracker.mu.Unlock()
	assert.Equal(t, total, tracker.CloseAll())
	// Calling CloseAll again must be safe and close nothing.
	assert.Equal(t, 0, tracker.CloseAll())
}
// TestTrackedConn_AutoDeregister verifies that closing a trackedConn removes
// it from the tracker without any explicit deregistration call.
func TestTrackedConn_AutoDeregister(t *testing.T) {
var tracker HijackTracker
c1, c2 := net.Pipe()
defer c2.Close()
tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"}
tracker.add(tc)
tracker.mu.Lock()
assert.Equal(t, 1, len(tracker.conns))
tracker.mu.Unlock()
// Close the tracked conn: should auto-deregister.
require.NoError(t, tc.Close())
tracker.mu.Lock()
assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close")
tracker.mu.Unlock()
}
// TestHostOnly exercises hostOnly over host:port and bare-host inputs,
// IPv4 and bracketed IPv6 forms, and the empty string.
func TestHostOnly(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{"example.com:443", "example.com"},
		{"example.com", "example.com"},
		{"127.0.0.1:8080", "127.0.0.1"},
		{"[::1]:443", "[::1]"},
		{"", ""},
	}
	for _, tc := range cases {
		t.Run(tc.in, func(t *testing.T) {
			assert.Equal(t, tc.want, hostOnly(tc.in))
		})
	}
}

View File

@@ -152,7 +152,7 @@ func (c *Client) printClients(data map[string]any) {
return
}
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "SERVICES", "HAS CLIENT")
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
for _, item := range clients {
@@ -166,7 +166,7 @@ func (c *Client) printClientRow(item any) {
return
}
domains := c.extractDomains(client)
services := c.extractServiceKeys(client)
hasClient := "no"
if hc, ok := client["has_client"].(bool); ok && hc {
hasClient = "yes"
@@ -175,20 +175,20 @@ func (c *Client) printClientRow(item any) {
_, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n",
client["account_id"],
client["age"],
domains,
services,
hasClient,
)
}
func (c *Client) extractDomains(client map[string]any) string {
d, ok := client["domains"].([]any)
func (c *Client) extractServiceKeys(client map[string]any) string {
d, ok := client["service_keys"].([]any)
if !ok || len(d) == 0 {
return "-"
}
parts := make([]string, len(d))
for i, domain := range d {
parts[i] = fmt.Sprint(domain)
for i, key := range d {
parts[i] = fmt.Sprint(key)
}
return strings.Join(parts, ", ")
}

View File

@@ -189,7 +189,7 @@ type indexData struct {
Version string
Uptime string
ClientCount int
TotalDomains int
TotalServices int
CertsTotal int
CertsReady int
CertsPending int
@@ -202,7 +202,7 @@ type indexData struct {
type clientData struct {
AccountID string
Domains string
Services string
Age string
Status string
}
@@ -211,9 +211,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
clients := h.provider.ListClientsForDebug()
sortedIDs := sortedAccountIDs(clients)
totalDomains := 0
totalServices := 0
for _, info := range clients {
totalDomains += info.DomainCount
totalServices += info.ServiceCount
}
var certsTotal, certsReady, certsPending, certsFailed int
@@ -234,24 +234,24 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
"account_id": info.AccountID,
"domain_count": info.DomainCount,
"domains": info.Domains,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
resp := map[string]interface{}{
"version": version.NetbirdVersion(),
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"total_domains": totalDomains,
"certs_total": certsTotal,
"certs_ready": certsReady,
"certs_pending": certsPending,
"certs_failed": certsFailed,
"clients": clientsJSON,
"version": version.NetbirdVersion(),
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"total_services": totalServices,
"certs_total": certsTotal,
"certs_ready": certsReady,
"certs_pending": certsPending,
"certs_failed": certsFailed,
"clients": clientsJSON,
}
if len(certsPendingDomains) > 0 {
resp["certs_pending_domains"] = certsPendingDomains
@@ -278,7 +278,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
Version: version.NetbirdVersion(),
Uptime: time.Since(h.startTime).Round(time.Second).String(),
ClientCount: len(clients),
TotalDomains: totalDomains,
TotalServices: totalServices,
CertsTotal: certsTotal,
CertsReady: certsReady,
CertsPending: certsPending,
@@ -291,9 +291,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
for _, id := range sortedIDs {
info := clients[id]
domains := info.Domains.SafeString()
if domains == "" {
domains = "-"
services := strings.Join(info.ServiceKeys, ", ")
if services == "" {
services = "-"
}
status := "No client"
if info.HasClient {
@@ -301,7 +301,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
}
data.Clients = append(data.Clients, clientData{
AccountID: string(info.AccountID),
Domains: domains,
Services: services,
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
Status: status,
})
@@ -324,12 +324,12 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
"account_id": info.AccountID,
"domain_count": info.DomainCount,
"domains": info.Domains,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
h.writeJSON(w, map[string]interface{}{
@@ -347,9 +347,9 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
for _, id := range sortedIDs {
info := clients[id]
domains := info.Domains.SafeString()
if domains == "" {
domains = "-"
services := strings.Join(info.ServiceKeys, ", ")
if services == "" {
services = "-"
}
status := "No client"
if info.HasClient {
@@ -357,7 +357,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
}
data.Clients = append(data.Clients, clientData{
AccountID: string(info.AccountID),
Domains: domains,
Services: services,
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
Status: status,
})

View File

@@ -12,14 +12,14 @@
<table>
<tr>
<th>Account ID</th>
<th>Domains</th>
<th>Services</th>
<th>Age</th>
<th>Status</th>
</tr>
{{range .Clients}}
<tr>
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
<td>{{.Domains}}</td>
<td>{{.Services}}</td>
<td>{{.Age}}</td>
<td>{{.Status}}</td>
</tr>

View File

@@ -27,19 +27,19 @@
<ul>{{range .CertsFailedDomains}}<li>{{.Domain}}: {{.Error}}</li>{{end}}</ul>
</details>
{{end}}
<h2>Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})</h2>
<h2>Clients ({{.ClientCount}}) | Services ({{.TotalServices}})</h2>
{{if .Clients}}
<table>
<tr>
<th>Account ID</th>
<th>Domains</th>
<th>Services</th>
<th>Age</th>
<th>Status</th>
</tr>
{{range .Clients}}
<tr>
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
<td>{{.Domains}}</td>
<td>{{.Services}}</td>
<td>{{.Age}}</td>
<td>{{.Status}}</td>
</tr>

View File

@@ -0,0 +1,69 @@
package metrics_test
import (
"context"
"reflect"
"testing"
"time"
promexporter "go.opentelemetry.io/otel/exporters/prometheus"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// newTestMetrics builds a Metrics instance backed by a real Prometheus
// exporter/meter provider so the recording calls can be exercised in tests.
func newTestMetrics(t *testing.T) *metrics.Metrics {
	t.Helper()
	exp, err := promexporter.New()
	if err != nil {
		t.Fatalf("create prometheus exporter: %v", err)
	}
	provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exp))
	// Use the metrics package's import path as the meter name.
	pkgPath := reflect.TypeOf(metrics.Metrics{}).PkgPath()
	m, err := metrics.New(context.Background(), provider.Meter(pkgPath))
	if err != nil {
		t.Fatalf("create metrics: %v", err)
	}
	return m
}
// TestL4ServiceGauge exercises the L4 service gauge add/remove paths for
// both TCP and UDP modes.
func TestL4ServiceGauge(t *testing.T) {
	met := newTestMetrics(t)
	met.L4ServiceAdded(types.ServiceModeTCP)
	met.L4ServiceAdded(types.ServiceModeTCP)
	met.L4ServiceAdded(types.ServiceModeUDP)
	met.L4ServiceRemoved(types.ServiceModeTCP)
}
// TestTCPRelayMetrics exercises the TCP relay lifecycle recording calls
// (start, end with byte/duration accumulation, dial error, rejection).
func TestTCPRelayMetrics(t *testing.T) {
	met := newTestMetrics(t)
	account := types.AccountID("acct-1")
	met.TCPRelayStarted(account)
	met.TCPRelayStarted(account)
	met.TCPRelayEnded(account, 10*time.Second, 1000, 500)
	met.TCPRelayDialError(account)
	met.TCPRelayRejected(account)
}
// TestUDPSessionMetrics exercises the UDP session lifecycle recording calls
// and per-direction packet/byte accounting.
func TestUDPSessionMetrics(t *testing.T) {
	met := newTestMetrics(t)
	account := types.AccountID("acct-2")
	met.UDPSessionStarted(account)
	met.UDPSessionStarted(account)
	met.UDPSessionEnded(account)
	met.UDPSessionDialError(account)
	met.UDPSessionRejected(account)
	met.UDPPacketRelayed(types.RelayDirectionClientToBackend, 100)
	met.UDPPacketRelayed(types.RelayDirectionClientToBackend, 200)
	met.UDPPacketRelayed(types.RelayDirectionBackendToClient, 150)
}

View File

@@ -6,12 +6,15 @@ import (
"sync"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// Metrics collects OpenTelemetry metrics for the proxy.
type Metrics struct {
ctx context.Context
requestsTotal metric.Int64Counter
@@ -22,85 +25,188 @@ type Metrics struct {
backendDuration metric.Int64Histogram
certificateIssueDuration metric.Int64Histogram
// L4 service-level metrics.
l4Services metric.Int64UpDownCounter
// L4 TCP connection-level metrics.
tcpActiveConns metric.Int64UpDownCounter
tcpConnsTotal metric.Int64Counter
tcpConnDuration metric.Int64Histogram
tcpBytesTotal metric.Int64Counter
// L4 UDP session-level metrics.
udpActiveSess metric.Int64UpDownCounter
udpSessionsTotal metric.Int64Counter
udpPacketsTotal metric.Int64Counter
udpBytesTotal metric.Int64Counter
mappingsMux sync.Mutex
mappingPaths map[string]int
}
// New creates a Metrics instance using the given OpenTelemetry meter.
func New(ctx context.Context, meter metric.Meter) (*Metrics, error) {
requestsTotal, err := meter.Int64Counter(
m := &Metrics{
ctx: ctx,
mappingPaths: make(map[string]int),
}
if err := m.initHTTPMetrics(meter); err != nil {
return nil, err
}
if err := m.initL4Metrics(meter); err != nil {
return nil, err
}
return m, nil
}
func (m *Metrics) initHTTPMetrics(meter metric.Meter) error {
var err error
m.requestsTotal, err = meter.Int64Counter(
"proxy.http.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Total number of requests made to the netbird proxy"),
)
if err != nil {
return nil, err
return err
}
activeRequests, err := meter.Int64UpDownCounter(
m.activeRequests, err = meter.Int64UpDownCounter(
"proxy.http.active_requests",
metric.WithUnit("1"),
metric.WithDescription("Current in-flight requests handled by the netbird proxy"),
)
if err != nil {
return nil, err
return err
}
configuredDomains, err := meter.Int64UpDownCounter(
m.configuredDomains, err = meter.Int64UpDownCounter(
"proxy.domains.count",
metric.WithUnit("1"),
metric.WithDescription("Current number of domains configured on the netbird proxy"),
)
if err != nil {
return nil, err
return err
}
totalPaths, err := meter.Int64UpDownCounter(
m.totalPaths, err = meter.Int64UpDownCounter(
"proxy.paths.count",
metric.WithUnit("1"),
metric.WithDescription("Total number of paths configured on the netbird proxy"),
)
if err != nil {
return nil, err
return err
}
requestDuration, err := meter.Int64Histogram(
m.requestDuration, err = meter.Int64Histogram(
"proxy.http.request.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of requests made to the netbird proxy"),
)
if err != nil {
return nil, err
return err
}
backendDuration, err := meter.Int64Histogram(
m.backendDuration, err = meter.Int64Histogram(
"proxy.backend.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of peer round trip time from the netbird proxy"),
)
if err != nil {
return nil, err
return err
}
certificateIssueDuration, err := meter.Int64Histogram(
m.certificateIssueDuration, err = meter.Int64Histogram(
"proxy.certificate.issue.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of ACME certificate issuance"),
)
return err
}
func (m *Metrics) initL4Metrics(meter metric.Meter) error {
var err error
m.l4Services, err = meter.Int64UpDownCounter(
"proxy.l4.services.count",
metric.WithUnit("1"),
metric.WithDescription("Current number of configured L4 services (TCP/TLS/UDP) by mode"),
)
if err != nil {
return nil, err
return err
}
return &Metrics{
ctx: ctx,
requestsTotal: requestsTotal,
activeRequests: activeRequests,
configuredDomains: configuredDomains,
totalPaths: totalPaths,
requestDuration: requestDuration,
backendDuration: backendDuration,
certificateIssueDuration: certificateIssueDuration,
mappingPaths: make(map[string]int),
}, nil
m.tcpActiveConns, err = meter.Int64UpDownCounter(
"proxy.tcp.active_connections",
metric.WithUnit("1"),
metric.WithDescription("Current number of active TCP/TLS relay connections"),
)
if err != nil {
return err
}
m.tcpConnsTotal, err = meter.Int64Counter(
"proxy.tcp.connections.total",
metric.WithUnit("1"),
metric.WithDescription("Total TCP/TLS relay connections by result and account"),
)
if err != nil {
return err
}
m.tcpConnDuration, err = meter.Int64Histogram(
"proxy.tcp.connection.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of TCP/TLS relay connections"),
)
if err != nil {
return err
}
m.tcpBytesTotal, err = meter.Int64Counter(
"proxy.tcp.bytes.total",
metric.WithUnit("bytes"),
metric.WithDescription("Total bytes transferred through TCP/TLS relay by direction"),
)
if err != nil {
return err
}
m.udpActiveSess, err = meter.Int64UpDownCounter(
"proxy.udp.active_sessions",
metric.WithUnit("1"),
metric.WithDescription("Current number of active UDP relay sessions"),
)
if err != nil {
return err
}
m.udpSessionsTotal, err = meter.Int64Counter(
"proxy.udp.sessions.total",
metric.WithUnit("1"),
metric.WithDescription("Total UDP relay sessions by result and account"),
)
if err != nil {
return err
}
m.udpPacketsTotal, err = meter.Int64Counter(
"proxy.udp.packets.total",
metric.WithUnit("1"),
metric.WithDescription("Total UDP packets relayed by direction"),
)
if err != nil {
return err
}
m.udpBytesTotal, err = meter.Int64Counter(
"proxy.udp.bytes.total",
metric.WithUnit("bytes"),
metric.WithDescription("Total bytes transferred through UDP relay by direction"),
)
return err
}
type responseInterceptor struct {
@@ -120,6 +226,13 @@ func (w *responseInterceptor) Write(b []byte) (int, error) {
return size, err
}
// Unwrap returns the underlying ResponseWriter so http.ResponseController
// can reach through to the original writer for Hijack/Flush operations.
func (w *responseInterceptor) Unwrap() http.ResponseWriter {
return w.PassthroughWriter
}
// Middleware wraps an HTTP handler with request metrics.
func (m *Metrics) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.requestsTotal.Add(m.ctx, 1)
@@ -144,6 +257,7 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
// RoundTripper wraps an http.RoundTripper with backend duration metrics.
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
start := time.Now()
@@ -156,6 +270,7 @@ func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
})
}
// AddMapping records that a domain mapping was added.
func (m *Metrics) AddMapping(mapping proxy.Mapping) {
m.mappingsMux.Lock()
defer m.mappingsMux.Unlock()
@@ -175,13 +290,13 @@ func (m *Metrics) AddMapping(mapping proxy.Mapping) {
m.mappingPaths[mapping.Host] = newPathCount
}
// RemoveMapping records that a domain mapping was removed.
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
m.mappingsMux.Lock()
defer m.mappingsMux.Unlock()
oldPathCount, exists := m.mappingPaths[mapping.Host]
if !exists {
// Nothing to remove
return
}
@@ -195,3 +310,80 @@ func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
func (m *Metrics) RecordCertificateIssuance(duration time.Duration) {
m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds())
}
// L4ServiceAdded increments the L4 service gauge for the given mode.
func (m *Metrics) L4ServiceAdded(mode types.ServiceMode) {
	attrs := metric.WithAttributes(attribute.String("mode", string(mode)))
	m.l4Services.Add(m.ctx, 1, attrs)
}
// L4ServiceRemoved decrements the L4 service gauge for the given mode.
func (m *Metrics) L4ServiceRemoved(mode types.ServiceMode) {
	attrs := metric.WithAttributes(attribute.String("mode", string(mode)))
	m.l4Services.Add(m.ctx, -1, attrs)
}
// TCPRelayStarted records a new TCP relay connection starting: it bumps the
// active-connection gauge and counts a successful connection for the account.
func (m *Metrics) TCPRelayStarted(accountID types.AccountID) {
	account := attribute.String("account_id", string(accountID))
	result := attribute.String("result", "success")
	m.tcpActiveConns.Add(m.ctx, 1, metric.WithAttributes(account))
	m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(account, result))
}
// TCPRelayEnded records a TCP relay connection ending and accumulates bytes and duration.
func (m *Metrics) TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) {
	account := metric.WithAttributes(attribute.String("account_id", string(accountID)))
	m.tcpActiveConns.Add(m.ctx, -1, account)
	m.tcpConnDuration.Record(m.ctx, duration.Milliseconds(), account)
	// Byte counters are labeled by direction only, not by account.
	m.tcpBytesTotal.Add(m.ctx, srcToDst, metric.WithAttributes(attribute.String("direction", "client_to_backend")))
	m.tcpBytesTotal.Add(m.ctx, dstToSrc, metric.WithAttributes(attribute.String("direction", "backend_to_client")))
}
// TCPRelayDialError records a dial failure for a TCP relay.
func (m *Metrics) TCPRelayDialError(accountID types.AccountID) {
	attrs := []attribute.KeyValue{
		attribute.String("account_id", string(accountID)),
		attribute.String("result", "dial_error"),
	}
	m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(attrs...))
}
// TCPRelayRejected records a rejected TCP relay (semaphore full).
func (m *Metrics) TCPRelayRejected(accountID types.AccountID) {
	attrs := []attribute.KeyValue{
		attribute.String("account_id", string(accountID)),
		attribute.String("result", "rejected"),
	}
	m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(attrs...))
}
// UDPSessionStarted records a new UDP session starting: it bumps the
// active-session gauge and counts a successful session for the account.
func (m *Metrics) UDPSessionStarted(accountID types.AccountID) {
	account := attribute.String("account_id", string(accountID))
	result := attribute.String("result", "success")
	m.udpActiveSess.Add(m.ctx, 1, metric.WithAttributes(account))
	m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(account, result))
}
// UDPSessionEnded records a UDP session ending.
func (m *Metrics) UDPSessionEnded(accountID types.AccountID) {
	account := attribute.String("account_id", string(accountID))
	m.udpActiveSess.Add(m.ctx, -1, metric.WithAttributes(account))
}
// UDPSessionDialError records a dial failure for a UDP session.
func (m *Metrics) UDPSessionDialError(accountID types.AccountID) {
	attrs := []attribute.KeyValue{
		attribute.String("account_id", string(accountID)),
		attribute.String("result", "dial_error"),
	}
	m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(attrs...))
}
// UDPSessionRejected records a rejected UDP session (limit or rate limited).
func (m *Metrics) UDPSessionRejected(accountID types.AccountID) {
	attrs := []attribute.KeyValue{
		attribute.String("account_id", string(accountID)),
		attribute.String("result", "rejected"),
	}
	m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(attrs...))
}
// UDPPacketRelayed records a packet relayed in the given direction with its size in bytes.
func (m *Metrics) UDPPacketRelayed(direction types.RelayDirection, bytes int) {
	attrs := metric.WithAttributes(attribute.String("direction", string(direction)))
	m.udpPacketsTotal.Add(m.ctx, 1, attrs)
	m.udpBytesTotal.Add(m.ctx, int64(bytes), attrs)
}

View File

@@ -0,0 +1,40 @@
package netutil
import (
"context"
"errors"
"fmt"
"io"
"math"
"net"
"syscall"
)
// ValidatePort converts an int32 proto port to uint16, returning an error
// if the value is out of the valid 165535 range.
func ValidatePort(port int32) (uint16, error) {
if port <= 0 || port > math.MaxUint16 {
return 0, fmt.Errorf("invalid port %d: must be 165535", port)
}
return uint16(port), nil
}
// IsExpectedError returns true for errors that are normal during
// connection teardown and should not be logged as warnings.
func IsExpectedError(err error) bool {
return errors.Is(err, net.ErrClosed) ||
errors.Is(err, context.Canceled) ||
errors.Is(err, io.EOF) ||
errors.Is(err, syscall.ECONNRESET) ||
errors.Is(err, syscall.EPIPE) ||
errors.Is(err, syscall.ECONNABORTED)
}
// IsTimeout checks whether the error is a network timeout.
func IsTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}

View File

@@ -0,0 +1,92 @@
package netutil
import (
"context"
"errors"
"fmt"
"io"
"net"
"syscall"
"testing"
"github.com/stretchr/testify/assert"
)
// TestValidatePort checks boundary and out-of-range behavior of ValidatePort.
func TestValidatePort(t *testing.T) {
	cases := []struct {
		name    string
		port    int32
		want    uint16
		wantErr bool
	}{
		{name: "valid min", port: 1, want: 1},
		{name: "valid mid", port: 8080, want: 8080},
		{name: "valid max", port: 65535, want: 65535},
		{name: "zero", port: 0, wantErr: true},
		{name: "negative", port: -1, wantErr: true},
		{name: "too large", port: 65536, wantErr: true},
		{name: "way too large", port: 100000, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ValidatePort(tc.port)
			if tc.wantErr {
				assert.Error(t, err)
				assert.Zero(t, got)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestIsExpectedError verifies classification of benign teardown errors,
// including wrapped forms, against errors that must not be suppressed.
func TestIsExpectedError(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want bool
	}{
		{name: "net.ErrClosed", err: net.ErrClosed, want: true},
		{name: "context.Canceled", err: context.Canceled, want: true},
		{name: "io.EOF", err: io.EOF, want: true},
		{name: "ECONNRESET", err: syscall.ECONNRESET, want: true},
		{name: "EPIPE", err: syscall.EPIPE, want: true},
		{name: "ECONNABORTED", err: syscall.ECONNABORTED, want: true},
		{name: "wrapped expected", err: fmt.Errorf("wrap: %w", net.ErrClosed), want: true},
		{name: "unexpected EOF", err: io.ErrUnexpectedEOF, want: false},
		{name: "generic error", err: errors.New("something"), want: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, IsExpectedError(tc.err))
		})
	}
}
// timeoutErr is a test double implementing the net.Error interface with a
// configurable Timeout() result, so IsTimeout can be exercised without
// performing real network operations.
type timeoutErr struct{ timeout bool }
func (e *timeoutErr) Error() string { return "timeout" }
// Timeout reports the configured timeout flag.
func (e *timeoutErr) Timeout() bool { return e.timeout }
// Temporary satisfies the historical net.Error interface; always false here.
func (e *timeoutErr) Temporary() bool { return false }
// TestIsTimeout verifies timeout detection for direct, wrapped, and
// non-timeout errors.
func TestIsTimeout(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want bool
	}{
		{name: "net timeout", err: &timeoutErr{timeout: true}, want: true},
		{name: "net non-timeout", err: &timeoutErr{timeout: false}, want: false},
		{name: "wrapped timeout", err: fmt.Errorf("wrap: %w", &timeoutErr{timeout: true}), want: true},
		{name: "generic error", err: errors.New("not a timeout"), want: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, IsTimeout(tc.err))
		})
	}
}

View File

@@ -2,6 +2,7 @@ package proxy
import (
"context"
"net/netip"
"sync"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -47,10 +48,10 @@ func (o ResponseOrigin) String() string {
type CapturedData struct {
mu sync.RWMutex
RequestID string
ServiceId string
ServiceId types.ServiceID
AccountId types.AccountID
Origin ResponseOrigin
ClientIP string
ClientIP netip.Addr
UserID string
AuthMethod string
}
@@ -63,14 +64,14 @@ func (c *CapturedData) GetRequestID() string {
}
// SetServiceId safely sets the service ID
func (c *CapturedData) SetServiceId(serviceId string) {
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) {
c.mu.Lock()
defer c.mu.Unlock()
c.ServiceId = serviceId
}
// GetServiceId safely gets the service ID
func (c *CapturedData) GetServiceId() string {
func (c *CapturedData) GetServiceId() types.ServiceID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ServiceId
@@ -105,14 +106,14 @@ func (c *CapturedData) GetOrigin() ResponseOrigin {
}
// SetClientIP safely sets the resolved client IP.
func (c *CapturedData) SetClientIP(ip string) {
func (c *CapturedData) SetClientIP(ip netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
c.ClientIP = ip
}
// GetClientIP safely gets the resolved client IP.
func (c *CapturedData) GetClientIP() string {
func (c *CapturedData) GetClientIP() netip.Addr {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ClientIP
@@ -161,13 +162,13 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
return data
}
func withServiceId(ctx context.Context, serviceId string) context.Context {
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
return context.WithValue(ctx, serviceIdKey, serviceId)
}
func ServiceIdFromContext(ctx context.Context) string {
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
v := ctx.Value(serviceIdKey)
serviceId, ok := v.(string)
serviceId, ok := v.(types.ServiceID)
if !ok {
return ""
}

View File

@@ -25,7 +25,7 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
func BenchmarkServeHTTP(b *testing.B) {
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
rp.AddMapping(proxy.Mapping{
ID: rand.Text(),
ID: types.ServiceID(rand.Text()),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: map[string]*proxy.PathTarget{
@@ -66,7 +66,7 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
target = id
}
rp.AddMapping(proxy.Mapping{
ID: id,
ID: types.ServiceID(id),
AccountID: types.AccountID(rand.Text()),
Host: host,
Paths: map[string]*proxy.PathTarget{
@@ -118,7 +118,7 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
}
}
rp.AddMapping(proxy.Mapping{
ID: rand.Text(),
ID: types.ServiceID(rand.Text()),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: paths,

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
)
@@ -86,9 +87,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx = roundtrip.WithSkipTLSVerify(ctx)
}
if pt.RequestTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
defer cancel()
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
}
rewriteMatchedPath := result.matchedPath
@@ -142,9 +141,9 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Header.Set(k, v)
}
clientIP := extractClientIP(r.In.RemoteAddr)
clientIP := extractHostIP(r.In.RemoteAddr)
if IsTrustedProxy(clientIP, p.trustedProxies) {
if isTrustedAddr(clientIP, p.trustedProxies) {
p.setTrustedForwardingHeaders(r, clientIP)
} else {
p.setUntrustedForwardingHeaders(r, clientIP)
@@ -214,12 +213,14 @@ func normalizeHost(u *url.URL) string {
// setTrustedForwardingHeaders appends to the existing forwarding header chain
// and preserves upstream-provided headers when the direct connection is from
// a trusted proxy.
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
ipStr := clientIP.String()
// Append the direct connection IP to the existing X-Forwarded-For chain.
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr)
} else {
r.Out.Header.Set("X-Forwarded-For", clientIP)
r.Out.Header.Set("X-Forwarded-For", ipStr)
}
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
@@ -227,7 +228,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
r.Out.Header.Set("X-Real-IP", realIP)
} else {
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
r.Out.Header.Set("X-Real-IP", resolved)
r.Out.Header.Set("X-Real-IP", resolved.String())
}
// Preserve upstream X-Forwarded-Host if present.
@@ -257,10 +258,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
// sets them fresh based on the direct connection. This is the default
// behavior when no trusted proxies are configured or the direct connection
// is from an untrusted source.
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
ipStr := clientIP.String()
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
r.Out.Header.Set("X-Forwarded-For", clientIP)
r.Out.Header.Set("X-Real-IP", clientIP)
r.Out.Header.Set("X-Forwarded-For", ipStr)
r.Out.Header.Set("X-Real-IP", ipStr)
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
r.Out.Header.Set("X-Forwarded-Proto", proto)
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
@@ -288,16 +290,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) {
}
}
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
// which is always in host:port format.
func extractClientIP(remoteAddr string) string {
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return ip
}
// extractForwardedPort returns the port from the Host header if present,
// otherwise defaults to the standard port for the resolved protocol.
func extractForwardedPort(host, resolvedProto string) string {
@@ -327,10 +319,12 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
web.ServeErrorPage(w, r, code, title, message, requestID, status)
}
// getClientIP retrieves the resolved client IP from context.
// getClientIP retrieves the resolved client IP string from context.
func getClientIP(r *http.Request) string {
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
return capturedData.GetClientIP()
if ip := capturedData.GetClientIP(); ip.IsValid() {
return ip.String()
}
}
return ""
}

View File

@@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
})
}
func TestExtractClientIP(t *testing.T) {
func TestExtractHostIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
expected netip.Addr
}{
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
{"IPv6 with port", "[::1]:12345", "::1"},
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
{"IPv6 without brackets fallback", "::1", "::1"},
{"empty string fallback", "", ""},
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
{"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")},
{"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")},
{"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")},
{"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")},
{"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")},
{"empty string fallback", "", netip.Addr{}},
{"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr))
})
}
}

View File

@@ -30,8 +30,9 @@ type PathTarget struct {
CustomHeaders map[string]string
}
// Mapping describes how a domain is routed by the HTTP reverse proxy.
type Mapping struct {
ID string
ID types.ServiceID
AccountID types.AccountID
Host string
Paths map[string]*PathTarget
@@ -42,7 +43,7 @@ type Mapping struct {
type targetResult struct {
target *PathTarget
matchedPath string
serviceID string
serviceID types.ServiceID
accountID types.AccountID
passHostHeader bool
rewriteRedirects bool
@@ -101,8 +102,13 @@ func (p *ReverseProxy) AddMapping(m Mapping) {
p.mappings[m.Host] = m
}
func (p *ReverseProxy) RemoveMapping(m Mapping) {
// RemoveMapping removes the mapping for the given host and reports whether it existed.
func (p *ReverseProxy) RemoveMapping(m Mapping) bool {
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
if _, ok := p.mappings[m.Host]; !ok {
return false
}
delete(p.mappings, m.Host)
return true
}

View File

@@ -7,21 +7,11 @@ import (
// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes.
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
if len(trusted) == 0 {
return false
}
addr, err := netip.ParseAddr(ipStr)
if err != nil {
if err != nil || len(trusted) == 0 {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
return isTrustedAddr(addr.Unmap(), trusted)
}
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
@@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
//
// If the trusted list is empty or remoteAddr is not trusted, it returns the
// remoteAddr IP directly (ignoring any forwarding headers).
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
remoteIP := extractClientIP(remoteAddr)
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr {
remoteIP := extractHostIP(remoteAddr)
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) {
return remoteIP
}
@@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
if ip == "" {
continue
}
if !IsTrustedProxy(ip, trusted) {
return ip
addr, err := netip.ParseAddr(ip)
if err != nil {
continue
}
addr = addr.Unmap()
if !isTrustedAddr(addr, trusted) {
return addr
}
}
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
if addr, err := netip.ParseAddr(first); err == nil {
return addr.Unmap()
}
}
return remoteIP
}
// extractHostIP parses the IP from a host:port string and returns it unmapped.
func extractHostIP(hostPort string) netip.Addr {
if ap, err := netip.ParseAddrPort(hostPort); err == nil {
return ap.Addr().Unmap()
}
if addr, err := netip.ParseAddr(hostPort); err == nil {
return addr.Unmap()
}
return netip.Addr{}
}
// isTrustedAddr checks if the given address falls within any of the trusted prefixes.
func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool {
if !addr.IsValid() {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}

View File

@@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) {
remoteAddr string
xff string
trusted []netip.Prefix
want string
want netip.Addr
}{
{
name: "empty trusted list returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4",
trusted: nil,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4, 10.0.0.1",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
xff: "",
trusted: trusted,
want: "10.0.0.1",
want: netip.MustParseAddr("10.0.0.1"),
},
{
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trusted: trusted,
want: "10.0.0.2",
want: netip.MustParseAddr("10.0.0.2"),
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xff: " 203.0.113.50 , 10.0.0.2 ",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "XFF with empty segments",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50,,10.0.0.2",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "RemoteAddr without port",
remoteAddr: "10.0.0.1",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
}
for _, tt := range tests {

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
@@ -14,11 +15,12 @@ import (
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/embed"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -26,7 +28,22 @@ import (
const deviceNamePrefix = "ingress-proxy-"
// backendKey identifies a backend by its host:port from the target URL.
type backendKey = string
type backendKey string
// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service)
// that holds a reference to an embedded NetBird client. Callers should use the
// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions.
type ServiceKey string

// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service.
func DomainServiceKey(domain string) ServiceKey {
	const prefix = "domain:"
	return ServiceKey(prefix + domain)
}
// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP).
func L4ServiceKey(id types.ServiceID) ServiceKey {
	return ServiceKey("l4:" + string(id))
}
var (
// ErrNoAccountID is returned when a request context is missing the account ID.
@@ -39,24 +56,24 @@ var (
ErrTooManyInflight = errors.New("too many in-flight requests")
)
// domainInfo holds metadata about a registered domain.
type domainInfo struct {
serviceID string
// serviceInfo holds metadata about a registered service.
type serviceInfo struct {
serviceID types.ServiceID
}
type domainNotification struct {
domain domain.Domain
serviceID string
type serviceNotification struct {
key ServiceKey
serviceID types.ServiceID
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
// clientEntry holds an embedded NetBird client and tracks which services use it.
type clientEntry struct {
client *embed.Client
transport *http.Transport
// insecureTransport is a clone of transport with TLS verification disabled,
// used when per-target skip_tls_verify is set.
insecureTransport *http.Transport
domains map[domain.Domain]domainInfo
services map[ServiceKey]serviceInfo
createdAt time.Time
started bool
// Per-backend in-flight limiting keyed by target host:port.
@@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
// ClientConfig holds configuration for the embedded NetBird client.
type ClientConfig struct {
MgmtAddr string
WGPort int
WGPort uint16
PreSharedKey string
}
type statusNotifier interface {
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error
}
type managementClient interface {
@@ -107,7 +124,7 @@ type managementClient interface {
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
type NetBird struct {
proxyID string
proxyAddr string
@@ -124,11 +141,11 @@ type NetBird struct {
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
AccountID types.AccountID
DomainCount int
Domains domain.List
HasClient bool
CreatedAt time.Time
AccountID types.AccountID
ServiceCount int
ServiceKeys []string
HasClient bool
CreatedAt time.Time
}
// accountIDContextKey is the context key for storing the account ID.
@@ -137,37 +154,37 @@ type accountIDContextKey struct{}
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
type skipTLSVerifyContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// AddPeer registers a service for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple domains can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
// Multiple services can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
// Client already exists for this account, just register the domain
entry.domains[d] = domainInfo{serviceID: serviceID}
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("registered domain with existing client")
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
// If client is already started, notify this domain as connected immediately
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return nil
}
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
if err != nil {
n.clientsMux.Unlock()
return err
@@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"account_id": accountID,
"service_key": key,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
@@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
@@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}).Debug("authenticating new proxy peer with management")
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: serviceID,
ServiceId: string(serviceID),
AccountId: string(accountID),
Token: authToken,
WireguardPublicKey: publicKey.String(),
@@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
wgPort := int(n.clientCfg.WGPort)
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
WireguardPort: &n.clientCfg.WGPort,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
})
if err != nil {
@@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
transport := &http.Transport{
DialContext: client.DialContext,
DialContext: dialWithTimeout(client.DialContext),
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
@@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
services: map[ServiceKey]serviceInfo{key: si},
transport: transport,
insecureTransport: insecureTransport,
createdAt: time.Now(),
@@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}, nil
}
// runClientStartup starts the client and notifies registered domains on success.
// runClientStartup starts the client and notifies registered services on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
return
}
// Mark client as started and collect domains to notify outside the lock.
// Mark client as started and collect services to notify outside the lock.
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
var domainsToNotify []domainNotification
var toNotify []serviceNotification
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
for key, info := range entry.services {
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
}
}
n.clientsMux.Unlock()
@@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
if n.statusNotifier == nil {
return
}
for _, dn := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
for _, sn := range toNotify {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
"account_id": accountID,
"service_key": sn.key,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
"account_id": accountID,
"service_key": sn.key,
}).Info("notified management about tunnel connection")
}
}
}
// RemovePeer unregisters a domain from an account. The client is only stopped
// when no domains are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
@@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
return nil
}
// Get domain info before deleting
domInfo, domainExists := entry.domains[d]
if !domainExists {
si, svcExists := entry.services[key]
if !svcExists {
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("remove peer: domain not registered")
"account_id": accountID,
"service_key": key,
}).Debug("remove peer: service not registered")
return nil
}
delete(entry.domains, d)
// If there are still domains using this client, keep it running
if len(entry.domains) > 0 {
n.clientsMux.Unlock()
delete(entry.services, key)
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"remaining_domains": len(entry.domains),
}).Debug("unregistered domain, client still in use")
// Notify this domain as disconnected
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
return nil
"account_id": accountID,
"service_key": key,
"remaining_services": len(entry.services),
}).Debug("unregistered service, client still in use")
}
// No more domains using this client, stop it
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Info("stopping client, no more domains")
client := entry.client
transport := entry.transport
insecureTransport := entry.insecureTransport
delete(n.clients, accountID)
n.clientsMux.Unlock()
// Notify disconnection before stopping
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Warn("failed to stop netbird client")
}
return nil
}
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
if n.statusNotifier == nil {
return
}
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil {
if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound {
n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
}
// RoundTrip implements http.RoundTripper. It looks up the client for the account
// specified in the request context and uses it to dial the backend.
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
}
n.clientsMux.RUnlock()
release, ok := entry.acquireInflight(req.URL.Host)
release, ok := entry.acquireInflight(backendKey(req.URL.Host))
defer release()
if !ok {
return nil, ErrTooManyInflight
@@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool {
return exists
}
// DomainCount returns the number of domains registered for the given account.
// ServiceCount returns the number of services registered for the given account.
// Returns 0 if the account has no client.
func (n *NetBird) DomainCount(accountID types.AccountID) int {
func (n *NetBird) ServiceCount(accountID types.AccountID) int {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return 0
}
return len(entry.domains)
return len(entry.services)
}
// ClientCount returns the total number of active clients.
@@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
result := make(map[types.AccountID]ClientDebugInfo)
for accountID, entry := range n.clients {
domains := make(domain.List, 0, len(entry.domains))
for d := range entry.domains {
domains = append(domains, d)
keys := make([]string, 0, len(entry.services))
for k := range entry.services {
keys = append(keys, string(k))
}
result[accountID] = ClientDebugInfo{
AccountID: accountID,
DomainCount: len(entry.domains),
Domains: domains,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
AccountID: accountID,
ServiceCount: len(entry.services),
ServiceKeys: keys,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
}
}
return result
@@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
}
}
// dialWithTimeout wraps a DialContext function so that any dial timeout
// stored in the context (via types.WithDialTimeout) is applied only to
// the connection establishment phase, not the full request lifetime.
func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
if d, ok := types.DialTimeoutFromContext(ctx); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, d)
defer cancel()
}
return dial(ctx, network, addr)
}
}
// WithAccountID adds the account ID to the context.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
return context.WithValue(ctx, accountIDContextKey{}, accountID)

View File

@@ -1,6 +1,7 @@
package roundtrip
import (
"context"
"crypto/rand"
"math/big"
"sync"
@@ -8,7 +9,6 @@ import (
"time"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
// Simple benchmark for comparison with AddPeer contention.
@@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) {
target = id
}
nb.clients[id] = &clientEntry{
domains: map[domain.Domain]domainInfo{
domain.Domain(rand.Text()): {
serviceID: rand.Text(),
services: map[ServiceKey]serviceInfo{
ServiceKey(rand.Text()): {
serviceID: types.ServiceID(rand.Text()),
},
},
createdAt: time.Now(),
@@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
target = id
}
nb.clients[id] = &clientEntry{
domains: map[domain.Domain]domainInfo{
domain.Domain(rand.Text()): {
serviceID: rand.Text(),
services: map[ServiceKey]serviceInfo{
ServiceKey(rand.Text()): {
serviceID: types.ServiceID(rand.Text()),
},
},
createdAt: time.Now(),
@@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
}
// Launch workers that continuously call AddPeer with new random accountIDs.
ctx, cancel := context.WithCancel(b.Context())
var wg sync.WaitGroup
for range addPeerWorkers {
wg.Go(func() {
for {
if err := nb.AddPeer(b.Context(),
wg.Add(1)
go func() {
defer wg.Done()
for ctx.Err() == nil {
if err := nb.AddPeer(ctx,
types.AccountID(rand.Text()),
domain.Domain(rand.Text()),
ServiceKey(rand.Text()),
rand.Text(),
rand.Text()); err != nil {
b.Log(err)
types.ServiceID(rand.Text())); err != nil {
return
}
}
})
}()
}
// Benchmark calling HasClient during AddPeer contention.
@@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
}
})
b.StopTimer()
cancel()
wg.Wait()
}

View File

@@ -11,7 +11,6 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -27,16 +26,15 @@ type mockStatusNotifier struct {
}
type statusCall struct {
accountID string
serviceID string
domain string
accountID types.AccountID
serviceID types.ServiceID
connected bool
}
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected})
return nil
}
@@ -62,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
// Initially no client exists.
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
// Add first domain - this should create a new client.
// Note: This will fail to actually connect since we use an invalid URL,
// but the client entry should still be created.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add first service - this should create a new client.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
}
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add first domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add first service.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.Equal(t, 1, nb.DomainCount(accountID))
assert.Equal(t, 1, nb.ServiceCount(accountID))
// Add second domain for the same account - should reuse existing client.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
// Add second service for the same account - should reuse existing client.
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
require.NoError(t, err)
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service")
// Add third domain.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
// Add third service.
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service")
// Still only one client.
assert.True(t, nb.HasClient(accountID))
@@ -102,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
account1 := types.AccountID("account-1")
account2 := types.AccountID("account-2")
// Add domain for account 1.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add service for account 1.
err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
// Add domain for account 2.
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
// Add service for account 2.
err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2"))
require.NoError(t, err)
// Both accounts should have their own clients.
assert.True(t, nb.HasClient(account1), "account1 should have client")
assert.True(t, nb.HasClient(account2), "account2 should have client")
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1")
assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1")
}
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add multiple domains.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add multiple services.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID))
assert.Equal(t, 3, nb.ServiceCount(accountID))
// Remove one domain - client should remain.
// Remove one service - client should remain.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
assert.True(t, nb.HasClient(accountID), "client should remain after removing one service")
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2")
// Remove another domain - client should still remain.
// Remove another service - client should still remain.
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
assert.True(t, nb.HasClient(accountID), "client should remain after removing second service")
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
}
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add single domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add single service.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
// Remove the only domain - client should be removed.
// Note: Stop() may fail since the client never actually connected,
// but the entry should still be removed from the map.
// Remove the only service - client should be removed.
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
// After removing all domains, client should be gone.
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
// After removing all services, client should be gone.
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service")
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
}
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
@@ -171,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
assert.NoError(t, err, "removing from non-existent account should not error")
}
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add one domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add one service.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
// Remove non-existent domain - should not affect existing domain.
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
// Remove non-existent service - should not affect existing service.
err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test")
require.NoError(t, err)
// Original domain should still be registered.
// Original service should still be registered.
assert.True(t, nb.HasClient(accountID))
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain")
}
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
@@ -216,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
account2 := types.AccountID("account-2")
account3 := types.AccountID("account-3")
// Add domains for multiple accounts.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
// Add services for multiple accounts.
err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3"))
require.NoError(t, err)
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
// Stop all clients.
// Note: StopAll may return errors since clients never actually connected,
// but the clients should still be removed from the map.
_ = nb.StopAll(context.Background())
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
@@ -243,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) {
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
// Add clients for different accounts.
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.Equal(t, 1, nb.ClientCount())
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2"))
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount())
// Adding domain to existing account should not increase count.
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
// Adding service to existing account should not increase count.
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b"))
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count")
}
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
@@ -293,8 +285,8 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
// Add first service — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
require.NoError(t, err)
// Manually mark client as started to simulate background startup completing.
@@ -302,15 +294,14 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second domain — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
// Add second service — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
require.NoError(t, err)
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, string(accountID), calls[0].accountID)
assert.Equal(t, "svc-2", calls[0].serviceID)
assert.Equal(t, "domain2.test", calls[0].domain)
assert.Equal(t, accountID, calls[0].accountID)
assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID)
assert.True(t, calls[0].connected)
}
@@ -323,18 +314,18 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
require.NoError(t, err)
// Remove one domain — client stays, but disconnection notification fires.
// Remove one service — client stays, but disconnection notification fires.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, "domain1.test", calls[0].domain)
assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID)
assert.False(t, calls[0].connected)
}

View File

@@ -0,0 +1,133 @@
package tcp
import (
"bytes"
"crypto/tls"
"io"
"net"
"testing"
)
// BenchmarkPeekClientHello_TLS measures the overhead of peeking at a real
// TLS ClientHello and extracting the SNI. This is the per-connection cost
// added to every TLS connection on the main listener.
func BenchmarkPeekClientHello_TLS(b *testing.B) {
	// Capture a genuine ClientHello by letting crypto/tls begin a
	// handshake against one end of an in-memory pipe.
	cliEnd, srvEnd := net.Pipe()
	go func() {
		c := tls.Client(cliEnd, &tls.Config{
			ServerName:         "app.example.com",
			InsecureSkipVerify: true, //nolint:gosec
		})
		_ = c.Handshake()
	}()
	scratch := make([]byte, 16384)
	n, _ := srvEnd.Read(scratch)
	hello := append([]byte(nil), scratch[:n]...)
	cliEnd.Close()
	srvEnd.Close()
	b.ResetTimer()
	b.ReportAllocs()
	for b.Loop() {
		conn := &readerConn{Reader: bytes.NewReader(hello)}
		sni, wrapped, err := PeekClientHello(conn)
		if err != nil {
			b.Fatal(err)
		}
		if sni != "app.example.com" {
			b.Fatalf("unexpected SNI: %q", sni)
		}
		// Simulate draining the peeked bytes (what the HTTP server would do).
		_, _ = io.Copy(io.Discard, wrapped)
	}
}
// BenchmarkPeekClientHello_NonTLS measures peek overhead for non-TLS
// connections that hit the fast non-handshake exit path.
func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
	// A plain HTTP request: the first bytes are not a TLS handshake
	// record, so the peek should bail out quickly.
	req := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
	b.ResetTimer()
	b.ReportAllocs()
	for b.Loop() {
		conn := &readerConn{Reader: bytes.NewReader(req)}
		_, wrapped, err := PeekClientHello(conn)
		if err != nil {
			b.Fatal(err)
		}
		_, _ = io.Copy(io.Discard, wrapped)
	}
}
// BenchmarkPeekedConn_Read measures the read overhead of the peekedConn
// wrapper compared to a plain connection read. The peeked bytes use
// io.MultiReader which adds one indirection per Read call.
func BenchmarkPeekedConn_Read(b *testing.B) {
	payload := make([]byte, 4096)
	replayed := make([]byte, 512)
	dst := make([]byte, 1024)
	b.ResetTimer()
	b.ReportAllocs()
	for b.Loop() {
		conn := &readerConn{Reader: bytes.NewReader(payload)}
		pc := newPeekedConn(conn, replayed)
		// Drain the wrapper to EOF; every Read goes through the MultiReader.
		for {
			if _, err := pc.Read(dst); err != nil {
				break
			}
		}
	}
}
// BenchmarkExtractSNI measures just the in-memory SNI parsing cost,
// excluding I/O.
func BenchmarkExtractSNI(b *testing.B) {
	cliEnd, srvEnd := net.Pipe()
	go func() {
		c := tls.Client(cliEnd, &tls.Config{
			ServerName:         "app.example.com",
			InsecureSkipVerify: true, //nolint:gosec
		})
		_ = c.Handshake()
	}()
	scratch := make([]byte, 16384)
	n, _ := srvEnd.Read(scratch)
	// Strip the TLS record header so extractSNI sees only the handshake payload.
	payload := append([]byte(nil), scratch[tlsRecordHeaderLen:n]...)
	cliEnd.Close()
	srvEnd.Close()
	b.ResetTimer()
	b.ReportAllocs()
	for b.Loop() {
		if sni := extractSNI(payload); sni != "app.example.com" {
			b.Fatalf("unexpected SNI: %q", sni)
		}
	}
}
// readerConn wraps an io.Reader as a net.Conn for benchmarking.
// Only Read is functional. The embedded net.Conn interface is left nil:
// it satisfies the interface's method set at compile time, but calling
// any method other than Read would panic on the nil interface. The
// benchmarks in this file only ever call Read.
type readerConn struct {
	io.Reader
	net.Conn
}

// Read delegates to the wrapped Reader, shadowing the embedded Conn's Read.
func (c *readerConn) Read(b []byte) (int, error) {
	return c.Reader.Read(b)
}

View File

@@ -0,0 +1,76 @@
package tcp
import (
"net"
"sync"
)
// chanListener implements net.Listener by reading connections from a channel.
// It allows the SNI router to feed HTTP connections to http.Server.ServeTLS.
type chanListener struct {
ch chan net.Conn
addr net.Addr
once sync.Once
closed chan struct{}
}
func newChanListener(ch chan net.Conn, addr net.Addr) *chanListener {
return &chanListener{
ch: ch,
addr: addr,
closed: make(chan struct{}),
}
}
// Accept waits for and returns the next connection from the channel.
func (l *chanListener) Accept() (net.Conn, error) {
for {
select {
case conn, ok := <-l.ch:
if !ok {
return nil, net.ErrClosed
}
return conn, nil
case <-l.closed:
// Drain buffered connections before returning.
for {
select {
case conn, ok := <-l.ch:
if !ok {
return nil, net.ErrClosed
}
_ = conn.Close()
default:
return nil, net.ErrClosed
}
}
}
}
}
// Close signals the listener to stop accepting connections and drains
// any buffered connections that have not yet been accepted.
func (l *chanListener) Close() error {
l.once.Do(func() {
close(l.closed)
for {
select {
case conn, ok := <-l.ch:
if !ok {
return
}
_ = conn.Close()
default:
return
}
}
})
return nil
}
// Addr returns the listener's network address.
func (l *chanListener) Addr() net.Addr {
return l.addr
}
var _ net.Listener = (*chanListener)(nil)

View File

@@ -0,0 +1,39 @@
package tcp
import (
"bytes"
"io"
"net"
)
// peekedConn wraps a net.Conn and prepends previously peeked bytes
// so that readers see the full original stream transparently.
type peekedConn struct {
	net.Conn
	reader io.Reader
}

// newPeekedConn returns conn with the given peeked bytes replayed ahead
// of any further reads from the underlying connection.
func newPeekedConn(conn net.Conn, peeked []byte) *peekedConn {
	replay := bytes.NewReader(peeked)
	return &peekedConn{
		Conn:   conn,
		reader: io.MultiReader(replay, conn),
	}
}

// Read replays the peeked bytes first, then reads from the underlying conn.
func (c *peekedConn) Read(b []byte) (int, error) {
	return c.reader.Read(b)
}

// CloseWrite delegates to the underlying connection if it supports
// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn
// as an interface hides the concrete type's CloseWrite method, making
// half-close a silent no-op for all SNI-routed connections.
func (c *peekedConn) CloseWrite() error {
	hc, ok := c.Conn.(halfCloser)
	if !ok {
		return nil
	}
	return hc.CloseWrite()
}

var _ halfCloser = (*peekedConn)(nil)

View File

@@ -0,0 +1,29 @@
package tcp
import (
"fmt"
"net"
"github.com/pires/go-proxyproto"
)
// writeProxyProtoV2 sends a PROXY protocol v2 header to the backend connection,
// conveying the real client address.
func writeProxyProtoV2(client, backend net.Conn) error {
	// Default to TCPv4; switch to TCPv6 only when the client's remote
	// address is a TCP address with no IPv4 form.
	transport := proxyproto.TCPv4
	if tcpAddr, ok := client.RemoteAddr().(*net.TCPAddr); ok && tcpAddr.IP.To4() == nil {
		transport = proxyproto.TCPv6
	}
	hdr := &proxyproto.Header{
		Version:           2,
		Command:           proxyproto.PROXY,
		TransportProtocol: transport,
		SourceAddr:        client.RemoteAddr(),
		DestinationAddr:   client.LocalAddr(),
	}
	if _, err := hdr.WriteTo(backend); err != nil {
		return fmt.Errorf("write PROXY protocol v2 header: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,128 @@
package tcp
import (
"bufio"
"net"
"testing"
"github.com/pires/go-proxyproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWriteProxyProtoV2_IPv4(t *testing.T) {
// Set up a real TCP listener and dial to get connections with real addresses.
ln, err := net.Listen("tcp4", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
var serverConn net.Conn
accepted := make(chan struct{})
go func() {
var err error
serverConn, err = ln.Accept()
if err != nil {
t.Error("accept failed:", err)
}
close(accepted)
}()
clientConn, err := net.Dial("tcp4", ln.Addr().String())
require.NoError(t, err)
defer clientConn.Close()
<-accepted
defer serverConn.Close()
// Use a pipe as the backend: write the header to one end, read from the other.
backendRead, backendWrite := net.Pipe()
defer backendRead.Close()
defer backendWrite.Close()
// serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination.
writeDone := make(chan error, 1)
go func() {
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
}()
// Read the PROXY protocol header from the backend read side.
header, err := proxyproto.Read(bufio.NewReader(backendRead))
require.NoError(t, err)
require.NotNil(t, header, "should have received a proxy protocol header")
writeErr := <-writeDone
require.NoError(t, writeErr)
assert.Equal(t, byte(2), header.Version, "version should be 2")
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4")
// serverConn.RemoteAddr() is the client's address (source in the header).
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
actualSrc := header.SourceAddr.(*net.TCPAddr)
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
// serverConn.LocalAddr() is the server's address (destination in the header).
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
actualDst := header.DestinationAddr.(*net.TCPAddr)
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
}
func TestWriteProxyProtoV2_IPv6(t *testing.T) {
// Set up a real TCP6 listener on loopback.
ln, err := net.Listen("tcp6", "[::1]:0")
if err != nil {
t.Skip("IPv6 not available:", err)
}
defer ln.Close()
var serverConn net.Conn
accepted := make(chan struct{})
go func() {
var err error
serverConn, err = ln.Accept()
if err != nil {
t.Error("accept failed:", err)
}
close(accepted)
}()
clientConn, err := net.Dial("tcp6", ln.Addr().String())
require.NoError(t, err)
defer clientConn.Close()
<-accepted
defer serverConn.Close()
backendRead, backendWrite := net.Pipe()
defer backendRead.Close()
defer backendWrite.Close()
writeDone := make(chan error, 1)
go func() {
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
}()
header, err := proxyproto.Read(bufio.NewReader(backendRead))
require.NoError(t, err)
require.NotNil(t, header, "should have received a proxy protocol header")
writeErr := <-writeDone
require.NoError(t, writeErr)
assert.Equal(t, byte(2), header.Version, "version should be 2")
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6")
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
actualSrc := header.SourceAddr.(*net.TCPAddr)
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
actualDst := header.DestinationAddr.(*net.TCPAddr)
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
}

156
proxy/internal/tcp/relay.go Normal file
View File

@@ -0,0 +1,156 @@
package tcp
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/netutil"
)
// errIdleTimeout is returned when a relay connection is closed due to inactivity.
var errIdleTimeout = errors.New("idle timeout")

// DefaultIdleTimeout is the default idle timeout for TCP relay connections.
// A zero value disables idle timeout checking.
const DefaultIdleTimeout = 5 * time.Minute

// halfCloser is implemented by connections that support half-close
// (e.g. *net.TCPConn). When one copy direction finishes, we signal
// EOF to the remote by closing the write side while keeping the read
// side open so the other direction can drain.
type halfCloser interface {
	CloseWrite() error
}

// copyBufPool avoids allocating a new 32KB buffer per io.Copy call.
// Pointers to slices are pooled (rather than slices directly) so that
// Put does not allocate when boxing the value into an interface.
var copyBufPool = sync.Pool{
	New: func() any {
		buf := make([]byte, 32*1024)
		return &buf
	},
}
// Relay copies data bidirectionally between src and dst until both
// sides are done or the context is canceled. When idleTimeout is
// non-zero, each direction's read is deadline-guarded; if no data
// flows within the timeout the connection is torn down. When one
// direction finishes, it half-closes the write side of the
// destination (if supported) to signal EOF, allowing the other
// direction to drain gracefully before the full connection teardown.
//
// Returns the byte counts copied in each direction. Both src and dst
// are closed by the time Relay returns.
func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Closing both conns on cancellation is what actually unblocks the
	// in-flight Read/Write calls and terminates the copy goroutines.
	go func() {
		<-ctx.Done()
		_ = src.Close()
		_ = dst.Close()
	}()
	var wg sync.WaitGroup
	wg.Add(2)
	var errSrcToDst, errDstToSrc error
	go func() {
		defer wg.Done()
		srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout)
		// Signal EOF to dst's peer while its read side stays open.
		// NOTE(review): cancel() fires immediately after the half-close,
		// which closes both conns — the drain window for the opposite
		// direction is effectively zero. Confirm this is intended.
		halfClose(dst)
		cancel()
	}()
	go func() {
		defer wg.Done()
		dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout)
		halfClose(src)
		cancel()
	}()
	wg.Wait()
	if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) {
		logger.Debug("relay closed due to idle timeout")
	}
	// Expected errors (EOF, closed conn, idle timeout) are not logged.
	if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) {
		logger.Debugf("relay copy error (src→dst): %v", errSrcToDst)
	}
	if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) {
		logger.Debugf("relay copy error (dst→src): %v", errDstToSrc)
	}
	return srcToDst, dstToSrc
}
// copyWithIdleTimeout copies from src to dst using a pooled buffer.
// When idleTimeout > 0 it sets a read deadline on src before each
// read and treats a timeout as an idle-triggered close.
// The loop follows io.Copy's conventions: bytes read are written out
// before the read error is inspected, and short writes are surfaced
// by checkedWrite.
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
	bufp := copyBufPool.Get().(*[]byte)
	defer copyBufPool.Put(bufp)
	if idleTimeout <= 0 {
		// No idle guard requested: plain buffered copy.
		return io.CopyBuffer(dst, src, *bufp)
	}
	conn, ok := src.(net.Conn)
	if !ok {
		// Deadlines require a net.Conn; fall back to a plain copy.
		return io.CopyBuffer(dst, src, *bufp)
	}
	buf := *bufp
	var total int64
	for {
		// Re-arm the deadline before every read so it measures idle
		// time, not total connection lifetime.
		if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
			return total, err
		}
		nr, readErr := src.Read(buf)
		if nr > 0 {
			n, err := checkedWrite(dst, buf[:nr])
			total += n
			if err != nil {
				return total, err
			}
		}
		if readErr != nil {
			if netutil.IsTimeout(readErr) {
				return total, errIdleTimeout
			}
			return total, readErr
		}
	}
}
// checkedWrite writes buf to dst and returns the number of bytes written.
// It guards against short writes and negative counts per io.Copy convention.
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
nw, err := dst.Write(buf)
if nw < 0 || nw > len(buf) {
nw = 0
}
if err != nil {
return int64(nw), err
}
if nw != len(buf) {
return int64(nw), io.ErrShortWrite
}
return int64(nw), nil
}
// isExpectedCopyError reports whether err is a benign relay-termination
// error (idle timeout or an expected network error) not worth logging.
func isExpectedCopyError(err error) bool {
	if errors.Is(err, errIdleTimeout) {
		return true
	}
	return netutil.IsExpectedError(err)
}
// halfClose attempts to half-close the write side of the connection.
// If the connection does not support half-close, this is a no-op.
func halfClose(conn net.Conn) {
	hc, ok := conn.(halfCloser)
	if !ok {
		return
	}
	// Best-effort; the full close will follow shortly.
	_ = hc.CloseWrite()
}

View File

@@ -0,0 +1,210 @@
package tcp
import (
"context"
"fmt"
"io"
"net"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/netutil"
)
// TestRelay_BidirectionalCopy verifies that data flows in both directions
// and that per-direction byte counts are reported correctly.
func TestRelay_BidirectionalCopy(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()
	logger := log.NewEntry(log.StandardLogger())

	wantUp := []byte("hello from src")
	wantDown := []byte("hello from dst")

	// Backend side: send its greeting, wait for the relayed request, close.
	go func() {
		_, _ = dstClient.Write(wantDown)
		scratch := make([]byte, 256)
		_, _ = dstClient.Read(scratch)
		dstClient.Close()
	}()
	// Client side: consume the greeting, then send its payload and close.
	go func() {
		scratch := make([]byte, 256)
		_, _ = srcClient.Read(scratch)
		_, _ = srcClient.Write(wantUp)
		srcClient.Close()
	}()

	up, down := Relay(context.Background(), logger, srcServer, dstServer, 0)
	assert.Equal(t, int64(len(wantUp)), up, "bytes src→dst")
	assert.Equal(t, int64(len(wantDown)), down, "bytes dst→src")
}
// TestRelay_ContextCancellation verifies that canceling the context makes
// Relay return promptly even when no data is flowing.
func TestRelay_ContextCancellation(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()
	defer srcClient.Close()
	defer dstClient.Close()
	logger := log.NewEntry(log.StandardLogger())

	ctx, cancel := context.WithCancel(context.Background())
	returned := make(chan struct{})
	go func() {
		defer close(returned)
		Relay(ctx, logger, srcServer, dstServer, 0)
	}()

	// Cancel should cause Relay to return.
	cancel()
	select {
	case <-returned:
	case <-time.After(5 * time.Second):
		t.Fatal("Relay did not return after context cancellation")
	}
}
// TestRelay_OneSideClosed verifies that the relay terminates when one
// endpoint is already closed before it starts.
func TestRelay_OneSideClosed(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()
	defer dstClient.Close()
	logger := log.NewEntry(log.StandardLogger())

	// Close src immediately; the relay must not hang on the dead side.
	srcClient.Close()

	returned := make(chan struct{})
	go func() {
		defer close(returned)
		Relay(context.Background(), logger, srcServer, dstServer, 0)
	}()
	select {
	case <-returned:
	case <-time.After(5 * time.Second):
		t.Fatal("Relay did not return after one side closed")
	}
}
// TestRelay_LargeTransfer verifies that a payload larger than the copy
// buffer is relayed completely and counted accurately.
func TestRelay_LargeTransfer(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()
	logger := log.NewEntry(log.StandardLogger())

	// 1MB of patterned data.
	payload := make([]byte, 1<<20)
	for i := range payload {
		payload[i] = byte(i % 256)
	}

	go func() {
		_, _ = srcClient.Write(payload)
		srcClient.Close()
	}()

	readErr := make(chan error, 1)
	go func() {
		got, err := io.ReadAll(dstClient)
		if err != nil {
			readErr <- err
			return
		}
		if len(got) != len(payload) {
			readErr <- fmt.Errorf("expected %d bytes, got %d", len(payload), len(got))
			return
		}
		readErr <- nil
		dstClient.Close()
	}()

	up, _ := Relay(context.Background(), logger, srcServer, dstServer, 0)
	assert.Equal(t, int64(len(payload)), up, "should transfer all bytes")
	require.NoError(t, <-readErr)
}
// TestRelay_IdleTimeout verifies that a relay with a non-zero idle
// timeout transfers live traffic and then tears down once no data flows.
func TestRelay_IdleTimeout(t *testing.T) {
	// Use real TCP connections so SetReadDeadline works (net.Pipe
	// does not support deadlines).
	srcLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer srcLn.Close()
	dstLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer dstLn.Close()
	srcClient, err := net.Dial("tcp", srcLn.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer srcClient.Close()
	srcServer, err := srcLn.Accept()
	if err != nil {
		t.Fatal(err)
	}
	dstClient, err := net.Dial("tcp", dstLn.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer dstClient.Close()
	dstServer, err := dstLn.Accept()
	if err != nil {
		t.Fatal(err)
	}
	logger := log.NewEntry(log.StandardLogger())
	ctx := context.Background()
	// Send initial data to prove the relay works.
	go func() {
		_, _ = srcClient.Write([]byte("ping"))
	}()
	done := make(chan struct{})
	var s2d, d2s int64
	go func() {
		// A short 200ms idle timeout keeps the test fast.
		s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
		close(done)
	}()
	// Read the forwarded data on the dst side.
	buf := make([]byte, 64)
	n, err := dstClient.Read(buf)
	assert.NoError(t, err)
	assert.Equal(t, "ping", string(buf[:n]))
	// Now stop sending. The relay should close after the idle timeout.
	select {
	case <-done:
		assert.Greater(t, s2d, int64(0), "should have transferred initial data")
		_ = d2s // dst never sends, so the reverse count is not asserted
	case <-time.After(5 * time.Second):
		t.Fatal("Relay did not exit after idle timeout")
	}
}
// TestIsExpectedError spot-checks netutil's classification of common
// relay-termination errors.
func TestIsExpectedError(t *testing.T) {
	cases := []struct {
		err  error
		want bool
	}{
		{net.ErrClosed, true},
		{context.Canceled, true},
		{io.EOF, true},
		{io.ErrUnexpectedEOF, false},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, netutil.IsExpectedError(tc.err), "err: %v", tc.err)
	}
}

View File

@@ -0,0 +1,570 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// defaultDialTimeout is the fallback dial timeout when no per-route
// timeout is configured (Route.DialTimeout == 0).
const defaultDialTimeout = 30 * time.Second

// SNIHost is a typed key for SNI hostname lookups.
type SNIHost string

// RouteType specifies how a connection should be handled.
type RouteType int

const (
	// RouteHTTP routes the connection through the HTTP reverse proxy.
	// As the zero value it is also the highest-priority route type
	// (lookupRoute picks the lowest Type on conflicts).
	RouteHTTP RouteType = iota
	// RouteTCP relays the connection directly to the backend (TLS passthrough).
	RouteTCP
)

const (
	// sniPeekTimeout is the deadline for reading the TLS ClientHello.
	sniPeekTimeout = 5 * time.Second
	// DefaultDrainTimeout is the default grace period for in-flight relay
	// connections to finish during shutdown.
	DefaultDrainTimeout = 30 * time.Second
	// DefaultMaxRelayConns is the default cap on concurrent TCP relay connections per router.
	DefaultMaxRelayConns = 4096
	// httpChannelBuffer is the capacity of the channel feeding HTTP connections.
	httpChannelBuffer = 4096
)

// DialResolver returns a DialContextFunc for the given account.
type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error)

// Route describes where a connection for a given SNI should be sent.
type Route struct {
	// Type selects HTTP proxying or raw TCP relay.
	Type      RouteType
	AccountID types.AccountID
	ServiceID types.ServiceID
	// Domain is the service's configured domain, used for access log entries.
	Domain string
	// Protocol is the frontend protocol (tcp, tls), used for access log entries.
	Protocol accesslog.Protocol
	// Target is the backend address for TCP relay (e.g. "10.0.0.5:5432").
	Target string
	// ProxyProtocol enables sending a PROXY protocol v2 header to the backend.
	ProxyProtocol bool
	// DialTimeout overrides the default dial timeout for this route.
	// Zero uses defaultDialTimeout.
	DialTimeout time.Duration
}

// l4Logger sends layer-4 access log entries to the management server.
type l4Logger interface {
	LogL4(entry accesslog.L4Entry)
}

// RelayObserver receives callbacks for TCP relay lifecycle events.
// All methods must be safe for concurrent use.
type RelayObserver interface {
	// TCPRelayStarted fires after the backend dial succeeds.
	TCPRelayStarted(accountID types.AccountID)
	// TCPRelayEnded reports relay duration and per-direction byte counts.
	TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64)
	// TCPRelayDialError fires when dialing the backend fails.
	TCPRelayDialError(accountID types.AccountID)
	// TCPRelayRejected fires when the concurrent-relay cap rejects a connection.
	TCPRelayRejected(accountID types.AccountID)
}
// Router accepts raw TCP connections on a shared listener, peeks at
// the TLS ClientHello to extract the SNI, and routes the connection
// to either the HTTP reverse proxy or a direct TCP relay.
type Router struct {
	logger *log.Logger
	// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
	httpCh       chan net.Conn
	httpListener *chanListener

	// mu guards routes, fallback, draining, drainDone, observer,
	// accessLog, and the per-service context maps below.
	mu       sync.RWMutex
	routes   map[SNIHost][]Route
	fallback *Route
	draining bool

	dialResolve DialResolver
	// activeConns counts in-flight handleConn goroutines; activeRelays
	// counts running TCP relays. Drain waits on both.
	activeConns  sync.WaitGroup
	activeRelays sync.WaitGroup
	// relaySem caps concurrent relays (buffered to DefaultMaxRelayConns).
	relaySem  chan struct{}
	drainDone chan struct{}
	observer  RelayObserver
	accessLog l4Logger
	// svcCtxs tracks a context per service ID. All relay goroutines for a
	// service derive from its context; canceling it kills them immediately.
	svcCtxs    map[types.ServiceID]context.Context
	svcCancels map[types.ServiceID]context.CancelFunc
}
// NewRouter creates a new SNI-based connection router with an HTTP
// channel, for the main listener.
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router {
	r := &Router{
		logger:      logger,
		routes:      make(map[SNIHost][]Route),
		dialResolve: dialResolve,
		relaySem:    make(chan struct{}, DefaultMaxRelayConns),
		svcCtxs:     make(map[types.ServiceID]context.Context),
		svcCancels:  make(map[types.ServiceID]context.CancelFunc),
	}
	r.httpCh = make(chan net.Conn, httpChannelBuffer)
	r.httpListener = newChanListener(r.httpCh, addr)
	return r
}
// NewPortRouter creates a Router for a dedicated port without an HTTP
// channel. Connections that don't match any SNI route fall through to
// the fallback relay (if set) or are closed.
func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router {
	return &Router{
		logger:      logger,
		dialResolve: dialResolve,
		routes:      map[SNIHost][]Route{},
		relaySem:    make(chan struct{}, DefaultMaxRelayConns),
		svcCtxs:     map[types.ServiceID]context.Context{},
		svcCancels:  map[types.ServiceID]context.CancelFunc{},
	}
}
// HTTPListener returns a net.Listener that yields connections routed
// to the HTTP handler. Use this with http.Server.ServeTLS.
// For routers created by NewPortRouter there is no HTTP channel, so a
// true nil interface is returned: returning the nil *chanListener
// directly would yield a non-nil net.Listener (typed-nil trap) that
// panics on Accept.
func (r *Router) HTTPListener() net.Listener {
	if r.httpListener == nil {
		return nil
	}
	return r.httpListener
}
// AddRoute registers an SNI route. Multiple routes for the same host are
// stored and resolved by priority at lookup time (HTTP > TCP).
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
func (r *Router) AddRoute(host SNIHost, route Route) {
	if host == "" {
		return
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	existing := r.routes[host]
	for i := range existing {
		if existing[i].ServiceID != route.ServiceID {
			continue
		}
		// Replacing an existing route: kill relays still running under
		// the old route before swapping it in place.
		r.cancelServiceLocked(route.ServiceID)
		existing[i] = route
		return
	}
	r.routes[host] = append(existing, route)
}
// RemoveRoute removes the route for the given host and service ID.
// Active relay connections for the service are closed immediately.
// If other routes remain for the host, they are preserved.
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	remaining := slices.DeleteFunc(r.routes[host], func(rt Route) bool {
		return rt.ServiceID == svcID
	})
	if len(remaining) == 0 {
		delete(r.routes, host)
	} else {
		r.routes[host] = remaining
	}
	r.cancelServiceLocked(svcID)
}
// SetFallback registers a catch-all route for connections that don't
// match any SNI route. On a port router this handles plain TCP relay;
// on the main router it takes priority over the HTTP channel.
func (r *Router) SetFallback(route Route) {
	r.mu.Lock()
	r.fallback = &route
	r.mu.Unlock()
}
// RemoveFallback clears the catch-all fallback route and closes any
// active relay connections for the given service.
func (r *Router) RemoveFallback(svcID types.ServiceID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.cancelServiceLocked(svcID)
	r.fallback = nil
}
// SetObserver sets the relay lifecycle observer. Must be called before Serve.
func (r *Router) SetObserver(obs RelayObserver) {
	r.mu.Lock()
	r.observer = obs
	r.mu.Unlock()
}

// SetAccessLogger sets the L4 access logger. Must be called before Serve.
func (r *Router) SetAccessLogger(l l4Logger) {
	r.mu.Lock()
	r.accessLog = l
	r.mu.Unlock()
}

// getObserver returns the current relay observer under the read lock.
func (r *Router) getObserver() RelayObserver {
	r.mu.RLock()
	obs := r.observer
	r.mu.RUnlock()
	return obs
}

// IsEmpty returns true when the router has no SNI routes and no fallback.
func (r *Router) IsEmpty() bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return len(r.routes) == 0 && r.fallback == nil
}
// Serve accepts connections from ln and routes them based on SNI.
// It blocks until ctx is canceled or ln is closed, then drains
// active relay connections up to DefaultDrainTimeout.
// Transient accept errors are retried with exponential backoff so a
// resource-exhaustion condition (e.g. running out of file descriptors)
// does not turn the accept loop into a hot spin.
func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
	done := make(chan struct{})
	defer close(done)
	// Close the listeners when ctx is canceled so Accept unblocks.
	go func() {
		select {
		case <-ctx.Done():
			_ = ln.Close()
			if r.httpListener != nil {
				r.httpListener.Close()
			}
		case <-done:
		}
	}()
	var retryDelay time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				// Shutdown path: give in-flight connections a grace period.
				if ok := r.Drain(DefaultDrainTimeout); !ok {
					r.logger.Warn("timed out waiting for connections to drain")
				}
				return nil
			}
			// Back off on accept errors (mirrors net/http.Server): start
			// at 5ms and double up to 1s, resetting on a successful accept.
			if retryDelay == 0 {
				retryDelay = 5 * time.Millisecond
			} else {
				retryDelay *= 2
			}
			if retryDelay > time.Second {
				retryDelay = time.Second
			}
			r.logger.Debugf("SNI router accept: %v", err)
			time.Sleep(retryDelay)
			continue
		}
		retryDelay = 0
		r.activeConns.Add(1)
		go func() {
			defer r.activeConns.Done()
			r.handleConn(ctx, conn)
		}()
	}
}
// handleConn peeks at the TLS ClientHello and routes the connection.
// Ownership: every path either hands the conn off (HTTP channel or
// relay) or closes it before returning.
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
	// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
	// fallback port), skip the TLS peek entirely to avoid read errors on
	// non-TLS connections and reduce latency.
	if r.isFallbackOnly() {
		r.handleUnmatched(ctx, conn)
		return
	}
	if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil {
		r.logger.Debugf("set SNI peek deadline: %v", err)
		_ = conn.Close()
		return
	}
	sni, wrapped, err := PeekClientHello(conn)
	if err != nil {
		r.logger.Debugf("SNI peek: %v", err)
		// A non-nil wrapped conn still carries the already-read bytes
		// (e.g. a non-TLS client), so give the fallback a chance at it.
		// NOTE(review): the peek read deadline is NOT cleared on this
		// path; relayTCP re-arms deadlines so the fallback is fine, but
		// a connection forwarded to the HTTP channel would inherit the
		// 5s deadline — confirm intended.
		if wrapped != nil {
			r.handleUnmatched(ctx, wrapped)
		} else {
			_ = conn.Close()
		}
		return
	}
	// Peek succeeded: remove the temporary deadline before routing.
	if err := wrapped.SetReadDeadline(time.Time{}); err != nil {
		r.logger.Debugf("clear SNI peek deadline: %v", err)
		_ = wrapped.Close()
		return
	}
	host := SNIHost(sni)
	route, ok := r.lookupRoute(host)
	if !ok {
		r.handleUnmatched(ctx, wrapped)
		return
	}
	if route.Type == RouteHTTP {
		r.sendToHTTP(wrapped)
		return
	}
	if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
		r.logger.WithFields(log.Fields{
			"sni":        host,
			"service_id": route.ServiceID,
			"target":     route.Target,
		}).Warnf("TCP relay: %v", err)
		// relayTCP only errors before taking ownership, so close here.
		_ = wrapped.Close()
	}
}
// isFallbackOnly returns true when the router has no SNI routes and no HTTP
// channel, meaning all connections should go directly to the fallback relay.
func (r *Router) isFallbackOnly() bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	if r.httpCh != nil {
		return false
	}
	return len(r.routes) == 0
}
// handleUnmatched routes a connection that didn't match any SNI route.
// This includes ECH/ESNI connections where the cleartext SNI is empty.
// It tries the fallback relay first, then the HTTP channel, and closes
// the connection if neither is available.
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
	r.mu.RLock()
	fb := r.fallback
	r.mu.RUnlock()
	if fb == nil {
		// sendToHTTP closes the conn itself when no HTTP channel exists.
		r.sendToHTTP(conn)
		return
	}
	if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
		r.logger.WithFields(log.Fields{
			"service_id": fb.ServiceID,
			"target":     fb.Target,
		}).Warnf("TCP relay (fallback): %v", err)
		_ = conn.Close()
	}
}
// lookupRoute returns the highest-priority route for the given SNI host.
// HTTP routes take precedence over TCP routes (lower RouteType wins).
func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	candidates := r.routes[host]
	if len(candidates) == 0 {
		return Route{}, false
	}
	best := candidates[0]
	for _, c := range candidates[1:] {
		if c.Type < best.Type {
			best = c
		}
	}
	return best, true
}
// sendToHTTP feeds the connection to the HTTP handler via the channel.
// If no HTTP channel is configured (port router), the router is
// draining, or the channel is full, the connection is closed.
func (r *Router) sendToHTTP(conn net.Conn) {
	r.mu.RLock()
	draining := r.draining
	r.mu.RUnlock()
	if r.httpCh == nil || draining {
		_ = conn.Close()
		return
	}
	// Non-blocking send: never stall the accept path on a slow HTTP server.
	select {
	case r.httpCh <- conn:
	default:
		r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr())
		_ = conn.Close()
	}
}
// Drain prevents new relay connections from starting and waits for all
// in-flight connection handlers and active relays to finish, up to the
// given timeout. Returns true if all completed, false on timeout.
// Safe to call multiple times: the waiter goroutine is started only
// once and later calls reuse its completion channel.
func (r *Router) Drain(timeout time.Duration) bool {
	r.mu.Lock()
	r.draining = true
	if r.drainDone == nil {
		// First call: spawn a single waiter shared by all Drain callers.
		done := make(chan struct{})
		go func() {
			r.activeConns.Wait()
			r.activeRelays.Wait()
			close(done)
		}()
		r.drainDone = done
	}
	done := r.drainDone
	r.mu.Unlock()
	select {
	case <-done:
		return true
	case <-time.After(timeout):
		return false
	}
}
// cancelServiceLocked cancels and removes the context for the given service,
// closing all its active relay connections. Must be called with mu held.
func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
	cancel, ok := r.svcCancels[svcID]
	if !ok {
		return
	}
	cancel()
	delete(r.svcCancels, svcID)
	delete(r.svcCtxs, svcID)
}
// relayTCP sets up and runs a bidirectional TCP relay.
// The caller owns conn and must close it if this method returns an error.
// On success (nil error), both conn and backend are closed by the relay.
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
	// Reserve a relay slot first; this fails when draining or at capacity.
	svcCtx, err := r.acquireRelay(ctx, route)
	if err != nil {
		return err
	}
	defer func() {
		<-r.relaySem
		r.activeRelays.Done()
	}()
	backend, err := r.dialBackend(svcCtx, route)
	if err != nil {
		obs := r.getObserver()
		if obs != nil {
			obs.TCPRelayDialError(route.AccountID)
		}
		return err
	}
	if route.ProxyProtocol {
		if err := writeProxyProtoV2(conn, backend); err != nil {
			// conn is closed by the caller on error; only backend is ours here.
			_ = backend.Close()
			return fmt.Errorf("write PROXY protocol header: %w", err)
		}
	}
	obs := r.getObserver()
	if obs != nil {
		obs.TCPRelayStarted(route.AccountID)
	}
	entry := r.logger.WithFields(log.Fields{
		"sni":        sni,
		"service_id": route.ServiceID,
		"target":     route.Target,
	})
	entry.Debug("TCP relay started")
	start := time.Now()
	// Relay closes both conn and backend before returning.
	s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
	elapsed := time.Since(start)
	if obs != nil {
		obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s)
	}
	entry.Debugf("TCP relay ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s)
	r.logL4Entry(route, conn, elapsed, s2d, d2s)
	return nil
}
// acquireRelay checks draining state, increments activeRelays, and acquires
// a semaphore slot. Returns the per-service context on success.
// The caller must release the semaphore and call activeRelays.Done() when done.
func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) {
	r.mu.Lock()
	if r.draining {
		r.mu.Unlock()
		return nil, errors.New("router is draining")
	}
	r.activeRelays.Add(1)
	svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID)
	r.mu.Unlock()

	// Non-blocking semaphore acquisition: reject rather than queue.
	select {
	case r.relaySem <- struct{}{}:
	default:
		r.activeRelays.Done()
		if obs := r.getObserver(); obs != nil {
			obs.TCPRelayRejected(route.AccountID)
		}
		return nil, errors.New("TCP relay connection limit reached")
	}
	return svcCtx, nil
}
// dialBackend resolves the dialer for the route's account and dials the backend.
func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) {
	dialFn, err := r.dialResolve(route.AccountID)
	if err != nil {
		return nil, fmt.Errorf("resolve dialer: %w", err)
	}

	// Per-route timeout with a package-default fallback.
	timeout := route.DialTimeout
	if timeout <= 0 {
		timeout = defaultDialTimeout
	}
	dialCtx, cancel := context.WithTimeout(svcCtx, timeout)
	defer cancel()

	backend, err := dialFn(dialCtx, "tcp", route.Target)
	if err != nil {
		return nil, fmt.Errorf("dial backend %s: %w", route.Target, err)
	}
	return backend, nil
}
// logL4Entry sends a TCP relay access log entry if an access logger is configured.
func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) {
	r.mu.RLock()
	al := r.accessLog
	r.mu.RUnlock()
	if al == nil {
		return
	}

	// Best-effort source address extraction; zero Addr when unparsable.
	var sourceIP netip.Addr
	if remote := conn.RemoteAddr(); remote != nil {
		if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
			sourceIP = ap.Addr().Unmap()
		}
	}

	entry := accesslog.L4Entry{
		AccountID:     route.AccountID,
		ServiceID:     route.ServiceID,
		Protocol:      route.Protocol,
		Host:          route.Domain,
		SourceIP:      sourceIP,
		DurationMs:    duration.Milliseconds(),
		BytesUpload:   bytesUp,
		BytesDownload: bytesDown,
	}
	al.LogL4(entry)
}
// getOrCreateServiceCtxLocked returns the context for a service, creating one
// if it doesn't exist yet. The context is a child of the server context.
// Must be called with mu held.
func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context {
	if existing, ok := r.svcCtxs[svcID]; ok {
		return existing
	}
	svcCtx, cancel := context.WithCancel(parent)
	r.svcCtxs[svcID] = svcCtx
	r.svcCancels[svcID] = cancel
	return svcCtx
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,191 @@
package tcp
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
)
const (
	// TLS record header is 5 bytes: ContentType(1) + Version(2) + Length(2).
	tlsRecordHeaderLen = 5
	// TLS handshake type for ClientHello.
	handshakeTypeClientHello = 1
	// TLS ContentType for handshake messages.
	contentTypeHandshake = 22
	// SNI extension type (RFC 6066).
	extensionServerName = 0
	// SNI host name type (host_name in RFC 6066's ServerName structure).
	sniHostNameType = 0
	// maxClientHelloLen caps the ClientHello size we're willing to buffer
	// (16 KiB, the TLS maximum plaintext record size).
	maxClientHelloLen = 16384
	// maxSNILen is the maximum valid DNS hostname length per RFC 1035.
	maxSNILen = 253
)
// PeekClientHello reads the TLS ClientHello from conn, extracts the SNI
// server name, and returns a wrapped connection that replays the peeked
// bytes transparently. If the data is not a valid TLS ClientHello or
// contains no SNI extension, sni is empty and err is nil.
//
// ECH/ESNI: with Encrypted Client Hello (TLS 1.3) the real server name
// lives in the encrypted_client_hello extension, not in cleartext. This
// parser only inspects the plaintext server_name extension (type 0x0000),
// so ECH connections yield sni="" and are routed through the fallback
// path (or HTTP channel) — the correct behavior for a transparent proxy
// that does not terminate TLS.
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
	var header [tlsRecordHeaderLen]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		return "", nil, fmt.Errorf("read TLS record header: %w", err)
	}

	recordLen := int(binary.BigEndian.Uint16(header[3:5]))
	notHandshake := header[0] != contentTypeHandshake
	implausibleLen := recordLen == 0 || recordLen > maxClientHelloLen
	if notHandshake || implausibleLen {
		// Not a plausible handshake record: replay the header and pass through.
		return "", newPeekedConn(conn, header[:]), nil
	}

	// One allocation holds header + payload; peekedConn takes ownership of
	// the buffer, so no further copies are needed.
	buf := make([]byte, tlsRecordHeaderLen+recordLen)
	copy(buf, header[:])
	n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:])
	if err != nil {
		return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err)
	}

	return extractSNI(buf[tlsRecordHeaderLen:]), newPeekedConn(conn, buf), nil
}
// extractSNI parses a TLS handshake payload to find the SNI extension.
// Returns empty string if the payload is not a ClientHello or has no SNI.
func extractSNI(payload []byte) string {
	const hdrLen = 4 // HandshakeType(1) + Length(3)
	if len(payload) < hdrLen || payload[0] != handshakeTypeClientHello {
		return ""
	}
	// 24-bit big-endian handshake message length.
	msgLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3])
	body := payload[hdrLen:]
	if msgLen > len(body) {
		return ""
	}
	return parseSNIFromClientHello(body[:msgLen])
}
// parseSNIFromClientHello walks the ClientHello message fields to reach
// the extensions block and extract the server_name extension value.
// Every length field is bounds-checked; any truncated or malformed
// structure yields "" rather than an error.
func parseSNIFromClientHello(msg []byte) string {
	// ClientHello layout:
	// ProtocolVersion(2) + Random(32) = 34 bytes minimum before session_id
	if len(msg) < 34 {
		return ""
	}
	pos := 34
	// Session ID (variable, 1 byte length prefix).
	if pos >= len(msg) {
		return ""
	}
	sessionIDLen := int(msg[pos])
	pos++
	pos += sessionIDLen
	// Cipher suites (variable, 2 byte length prefix).
	// An oversized sessionIDLen is caught here because pos then exceeds len(msg).
	if pos+2 > len(msg) {
		return ""
	}
	cipherSuitesLen := int(binary.BigEndian.Uint16(msg[pos : pos+2]))
	pos += 2 + cipherSuitesLen
	// Compression methods (variable, 1 byte length prefix).
	if pos >= len(msg) {
		return ""
	}
	compMethodsLen := int(msg[pos])
	pos++
	pos += compMethodsLen
	// Extensions (variable, 2 byte length prefix).
	if pos+2 > len(msg) {
		return ""
	}
	extensionsLen := int(binary.BigEndian.Uint16(msg[pos : pos+2]))
	pos += 2
	extensionsEnd := pos + extensionsLen
	if extensionsEnd > len(msg) {
		return ""
	}
	return findSNIExtension(msg[pos:extensionsEnd])
}
// findSNIExtension iterates over TLS extensions and returns the host
// name from the server_name extension, if present.
func findSNIExtension(extensions []byte) string {
	rest := extensions
	for len(rest) >= 4 {
		extType := binary.BigEndian.Uint16(rest[:2])
		extLen := int(binary.BigEndian.Uint16(rest[2:4]))
		rest = rest[4:]
		if extLen > len(rest) {
			// Extension claims more bytes than remain: malformed.
			return ""
		}
		if extType == extensionServerName {
			return parseSNIExtensionData(rest[:extLen])
		}
		rest = rest[extLen:]
	}
	return ""
}
// parseSNIExtensionData parses the ServerNameList structure inside an
// SNI extension to extract the host name.
func parseSNIExtensionData(data []byte) string {
	if len(data) < 2 {
		return ""
	}
	listLen := int(binary.BigEndian.Uint16(data[:2]))
	if listLen > len(data)-2 {
		return ""
	}
	rest := data[2 : 2+listLen]
	for len(rest) >= 3 {
		nameType := rest[0]
		nameLen := int(binary.BigEndian.Uint16(rest[1:3]))
		rest = rest[3:]
		if nameLen > len(rest) {
			return ""
		}
		if nameType == sniHostNameType {
			name := rest[:nameLen]
			// Reject impossible host names: overlong or containing NUL.
			if nameLen > maxSNILen || bytes.ContainsRune(name, 0) {
				return ""
			}
			return string(name)
		}
		rest = rest[nameLen:]
	}
	return ""
}

View File

@@ -0,0 +1,251 @@
package tcp
import (
"crypto/tls"
"io"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPeekClientHello_ValidSNI sends a real ClientHello via crypto/tls and
// checks that the SNI is extracted and that peeked bytes are replayed before
// subsequent live data.
func TestPeekClientHello_ValidSNI(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()
	const expectedSNI = "example.com"
	trailingData := []byte("trailing data after handshake")
	go func() {
		tlsConn := tls.Client(clientConn, &tls.Config{
			ServerName:         expectedSNI,
			InsecureSkipVerify: true, //nolint:gosec
		})
		// The Handshake will send the ClientHello. It will fail because
		// our server side isn't doing a real TLS handshake, but that's
		// fine: we only need the ClientHello to be sent.
		_ = tlsConn.Handshake()
	}()
	sni, wrapped, err := PeekClientHello(serverConn)
	require.NoError(t, err)
	assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello")
	assert.NotNil(t, wrapped, "wrapped connection should not be nil")
	// Verify the wrapped connection replays the peeked bytes.
	// Read the first 5 bytes (TLS record header) to confirm replay.
	buf := make([]byte, 5)
	n, err := wrapped.Read(buf)
	require.NoError(t, err)
	assert.Equal(t, 5, n)
	assert.Equal(t, byte(contentTypeHandshake), buf[0], "first byte should be TLS handshake content type")
	// Write trailing data from the client side and verify it arrives
	// through the wrapped connection after the peeked bytes.
	go func() {
		_, _ = clientConn.Write(trailingData)
	}()
	// Drain the rest of the peeked ClientHello first.
	// NOTE(review): assumes a single Read returns all remaining buffered
	// peeked bytes — appears to hold for peekedConn, worth confirming.
	peekedRest := make([]byte, 16384)
	_, _ = wrapped.Read(peekedRest)
	got := make([]byte, len(trailingData))
	n, err = io.ReadFull(wrapped, got)
	require.NoError(t, err)
	assert.Equal(t, trailingData, got[:n])
}
// TestPeekClientHello_MultipleSNIs checks SNI extraction across several
// host name shapes (bare domain through deep subdomains).
func TestPeekClientHello_MultipleSNIs(t *testing.T) {
	tests := []struct {
		name        string
		serverName  string
		expectedSNI string
	}{
		{"simple domain", "example.com", "example.com"},
		{"subdomain", "sub.example.com", "sub.example.com"},
		{"deep subdomain", "a.b.c.example.com", "a.b.c.example.com"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			clientConn, serverConn := net.Pipe()
			defer clientConn.Close()
			defer serverConn.Close()
			go func() {
				tlsConn := tls.Client(clientConn, &tls.Config{
					ServerName:         tt.serverName,
					InsecureSkipVerify: true, //nolint:gosec
				})
				// Handshake fails (no real TLS server) but sends the ClientHello.
				_ = tlsConn.Handshake()
			}()
			sni, wrapped, err := PeekClientHello(serverConn)
			require.NoError(t, err)
			assert.Equal(t, tt.expectedSNI, sni)
			assert.NotNil(t, wrapped)
		})
	}
}
// TestPeekClientHello_NonTLSData verifies that non-TLS traffic yields no SNI,
// no error, and a wrapped connection that replays the original bytes intact.
func TestPeekClientHello_NonTLSData(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()
	// Send plain HTTP data (not TLS).
	httpData := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
	go func() {
		_, _ = clientConn.Write(httpData)
	}()
	sni, wrapped, err := PeekClientHello(serverConn)
	require.NoError(t, err)
	assert.Empty(t, sni, "should return empty SNI for non-TLS data")
	assert.NotNil(t, wrapped)
	// Verify the wrapped connection still provides the original data.
	buf := make([]byte, len(httpData))
	n, err := io.ReadFull(wrapped, buf)
	require.NoError(t, err)
	assert.Equal(t, httpData, buf[:n], "wrapped connection should replay original data")
}
// TestPeekClientHello_TruncatedHeader verifies that an error is returned when
// the peer closes before a full 5-byte TLS record header arrives.
func TestPeekClientHello_TruncatedHeader(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer serverConn.Close()
	// Write only 3 bytes then close, fewer than the 5-byte TLS header.
	go func() {
		_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01})
		clientConn.Close()
	}()
	_, _, err := PeekClientHello(serverConn)
	assert.Error(t, err, "should error on truncated header")
}
// TestPeekClientHello_TruncatedPayload verifies that an error is returned
// when the record header promises more handshake bytes than the peer sends.
func TestPeekClientHello_TruncatedPayload(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer serverConn.Close()
	// Write a valid TLS header claiming 100 bytes, but only send 10.
	go func() {
		header := []byte{0x16, 0x03, 0x01, 0x00, 0x64} // 100 bytes claimed
		_, _ = clientConn.Write(header)
		_, _ = clientConn.Write(make([]byte, 10))
		clientConn.Close()
	}()
	_, _, err := PeekClientHello(serverConn)
	assert.Error(t, err, "should error on truncated payload")
}
// TestPeekClientHello_ZeroLengthRecord verifies that a handshake record with
// a zero-length payload is passed through with empty SNI and no error.
func TestPeekClientHello_ZeroLengthRecord(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()
	// TLS handshake header with zero-length payload.
	go func() {
		_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00})
	}()
	sni, wrapped, err := PeekClientHello(serverConn)
	require.NoError(t, err)
	assert.Empty(t, sni)
	assert.NotNil(t, wrapped)
}
// TestExtractSNI_InvalidPayload exercises extractSNI directly with malformed
// handshake payloads; all must yield an empty SNI without panicking.
func TestExtractSNI_InvalidPayload(t *testing.T) {
	tests := []struct {
		name    string
		payload []byte
	}{
		{"nil", nil},
		{"empty", []byte{}},
		{"too short", []byte{0x01, 0x00}},
		{"wrong handshake type", []byte{0x02, 0x00, 0x00, 0x05, 0x03, 0x03, 0x00, 0x00, 0x00}},
		{"truncated client hello", []byte{0x01, 0x00, 0x00, 0x20}}, // claims 32 bytes but has none
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Empty(t, extractSNI(tt.payload))
		})
	}
}
// TestPeekedConn_CloseWrite checks half-close behavior of the wrapper: it
// delegates to the underlying connection when that supports CloseWrite
// (real TCP) and is a no-op otherwise (net.Pipe).
func TestPeekedConn_CloseWrite(t *testing.T) {
	t.Run("delegates to underlying TCPConn", func(t *testing.T) {
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		defer ln.Close()
		accepted := make(chan net.Conn, 1)
		go func() {
			c, err := ln.Accept()
			if err == nil {
				accepted <- c
			}
		}()
		client, err := net.Dial("tcp", ln.Addr().String())
		require.NoError(t, err)
		defer client.Close()
		server := <-accepted
		defer server.Close()
		wrapped := newPeekedConn(server, []byte("peeked"))
		// CloseWrite should succeed on a real TCP connection.
		err = wrapped.CloseWrite()
		assert.NoError(t, err)
		// The client should see EOF on reads after CloseWrite.
		buf := make([]byte, 1)
		_, err = client.Read(buf)
		assert.Equal(t, io.EOF, err, "client should see EOF after half-close")
	})
	t.Run("no-op on non-halfcloser", func(t *testing.T) {
		// net.Pipe does not implement CloseWrite.
		_, server := net.Pipe()
		defer server.Close()
		wrapped := newPeekedConn(server, []byte("peeked"))
		err := wrapped.CloseWrite()
		assert.NoError(t, err, "should be no-op on non-halfcloser")
	})
}
// TestPeekedConn_ReplayAndPassthrough verifies the two-phase Read behavior:
// buffered peeked bytes first, then live data from the wrapped connection.
func TestPeekedConn_ReplayAndPassthrough(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()
	peeked := []byte("peeked-data")
	subsequent := []byte("subsequent-data")
	wrapped := newPeekedConn(serverConn, peeked)
	go func() {
		_, _ = clientConn.Write(subsequent)
	}()
	// Read should return peeked data first.
	buf := make([]byte, len(peeked))
	n, err := io.ReadFull(wrapped, buf)
	require.NoError(t, err)
	assert.Equal(t, peeked, buf[:n])
	// Then subsequent data from the real connection.
	buf = make([]byte, len(subsequent))
	n, err = io.ReadFull(wrapped, buf)
	require.NoError(t, err)
	assert.Equal(t, subsequent, buf[:n])
}

View File

@@ -1,5 +1,56 @@
// Package types defines common types used across the proxy package.
package types
import (
"context"
"net"
"time"
)
// AccountID represents a unique identifier for a NetBird account.
type AccountID string

// ServiceID represents a unique identifier for a proxy service.
type ServiceID string

// ServiceMode describes how a reverse proxy service is exposed.
type ServiceMode string

const (
	ServiceModeHTTP ServiceMode = "http"
	ServiceModeTCP  ServiceMode = "tcp"
	ServiceModeUDP  ServiceMode = "udp"
	ServiceModeTLS  ServiceMode = "tls"
)

// IsL4 reports whether the mode is a layer-4 mode (TCP, UDP, or TLS).
func (m ServiceMode) IsL4() bool {
	switch m {
	case ServiceModeTCP, ServiceModeUDP, ServiceModeTLS:
		return true
	default:
		return false
	}
}
// RelayDirection indicates the direction of a relayed packet,
// as reported to observer callbacks (e.g. UDPPacketRelayed).
type RelayDirection string

const (
	RelayDirectionClientToBackend RelayDirection = "client_to_backend"
	RelayDirectionBackendToClient RelayDirection = "backend_to_client"
)

// DialContextFunc dials a backend through the WireGuard tunnel.
// It matches the signature of net.Dialer.DialContext.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// dialTimeoutKey is the context key for a per-request dial timeout.
type dialTimeoutKey struct{}
// WithDialTimeout returns a context carrying a dial timeout that
// DialContext wrappers can use to scope the timeout to just the
// connection establishment phase.
func WithDialTimeout(ctx context.Context, d time.Duration) context.Context {
return context.WithValue(ctx, dialTimeoutKey{}, d)
}
// DialTimeoutFromContext returns the dial timeout from the context, if set.
func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) {
d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration)
return d, ok && d > 0
}

View File

@@ -0,0 +1,54 @@
package types
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestServiceMode_IsL4 covers every declared mode plus an unknown value.
func TestServiceMode_IsL4(t *testing.T) {
	tests := []struct {
		mode ServiceMode
		want bool
	}{
		{ServiceModeHTTP, false},
		{ServiceModeTCP, true},
		{ServiceModeUDP, true},
		{ServiceModeTLS, true},
		{ServiceMode("unknown"), false},
	}
	for _, tt := range tests {
		t.Run(string(tt.mode), func(t *testing.T) {
			assert.Equal(t, tt.want, tt.mode.IsL4())
		})
	}
}
// TestDialTimeoutContext checks the context round trip plus the unset,
// zero, and negative cases (the latter two must report ok=false).
func TestDialTimeoutContext(t *testing.T) {
	t.Run("round trip", func(t *testing.T) {
		ctx := WithDialTimeout(context.Background(), 5*time.Second)
		d, ok := DialTimeoutFromContext(ctx)
		assert.True(t, ok)
		assert.Equal(t, 5*time.Second, d)
	})
	t.Run("missing", func(t *testing.T) {
		_, ok := DialTimeoutFromContext(context.Background())
		assert.False(t, ok)
	})
	t.Run("zero returns false", func(t *testing.T) {
		ctx := WithDialTimeout(context.Background(), 0)
		_, ok := DialTimeoutFromContext(ctx)
		assert.False(t, ok, "zero duration should return ok=false")
	})
	t.Run("negative returns false", func(t *testing.T) {
		ctx := WithDialTimeout(context.Background(), -1*time.Second)
		_, ok := DialTimeoutFromContext(ctx)
		assert.False(t, ok, "negative duration should return ok=false")
	})
}

496
proxy/internal/udp/relay.go Normal file
View File

@@ -0,0 +1,496 @@
package udp
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/types"
)
const (
	// DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup.
	DefaultSessionTTL = 30 * time.Second
	// cleanupInterval is how often the cleaner goroutine runs.
	cleanupInterval = time.Minute
	// maxPacketSize is the maximum UDP packet size we'll handle
	// (the 16-bit length ceiling of a UDP datagram).
	maxPacketSize = 65535
	// DefaultMaxSessions is the default cap on concurrent UDP sessions per relay.
	DefaultMaxSessions = 1024
	// sessionCreateRate limits new session creation per second.
	sessionCreateRate = 50
	// sessionCreateBurst is the burst allowance for session creation.
	sessionCreateBurst = 100
	// defaultDialTimeout is the fallback dial timeout for backend connections.
	defaultDialTimeout = 30 * time.Second
)

// l4Logger sends layer-4 access log entries to the management server.
type l4Logger interface {
	LogL4(entry accesslog.L4Entry)
}

// SessionObserver receives callbacks for UDP session lifecycle events.
// All methods must be safe for concurrent use.
type SessionObserver interface {
	UDPSessionStarted(accountID types.AccountID)
	UDPSessionEnded(accountID types.AccountID)
	UDPSessionDialError(accountID types.AccountID)
	UDPSessionRejected(accountID types.AccountID)
	UDPPacketRelayed(direction types.RelayDirection, bytes int)
}

// clientAddr is a typed key for UDP session lookups; it holds the string
// form of the client's remote address.
type clientAddr string
// Relay listens for incoming UDP packets on a dedicated port and
// maintains per-client sessions that relay packets to a backend
// through the WireGuard tunnel.
type Relay struct {
	logger      *log.Entry
	listener    net.PacketConn
	target      string
	domain      string
	accountID   types.AccountID
	serviceID   types.ServiceID
	dialFunc    types.DialContextFunc
	dialTimeout time.Duration
	sessionTTL  time.Duration
	maxSessions int
	// mu guards sessions.
	mu       sync.RWMutex
	sessions map[clientAddr]*session
	// bufPool recycles *[]byte packet buffers of length maxPacketSize.
	bufPool sync.Pool
	// sessLimiter throttles new-session creation (rate + burst).
	sessLimiter *rate.Limiter
	// sessWg tracks the per-session backend-to-client goroutines so Close
	// can wait for them.
	sessWg sync.WaitGroup
	// ctx/cancel bound the relay's lifetime; session contexts derive from ctx.
	ctx    context.Context
	cancel context.CancelFunc
	observer  SessionObserver
	accessLog l4Logger
}
// session tracks one client address's flow through the relay: a dedicated
// backend connection plus activity and byte counters.
type session struct {
	backend   net.Conn
	addr      net.Addr
	createdAt time.Time
	// lastSeen stores the last activity timestamp as unix nanoseconds.
	lastSeen atomic.Int64
	// cancel stops the session's backend-to-client relay goroutine.
	cancel context.CancelFunc
	// bytesIn tracks total bytes received from the client.
	bytesIn atomic.Int64
	// bytesOut tracks total bytes sent back to the client.
	bytesOut atomic.Int64
}

// updateLastSeen records the current time as the session's last activity.
func (s *session) updateLastSeen() {
	s.lastSeen.Store(time.Now().UnixNano())
}

// idleDuration reports how long the session has gone without activity.
func (s *session) idleDuration() time.Duration {
	return time.Since(time.Unix(0, s.lastSeen.Load()))
}
// RelayConfig holds the configuration for a UDP relay.
// Zero values for DialTimeout, SessionTTL, and MaxSessions select the
// package defaults (see New).
type RelayConfig struct {
	Logger      *log.Entry
	Listener    net.PacketConn
	Target      string
	Domain      string
	AccountID   types.AccountID
	ServiceID   types.ServiceID
	DialFunc    types.DialContextFunc
	DialTimeout time.Duration
	SessionTTL  time.Duration
	MaxSessions int
	AccessLog   l4Logger
}
// New creates a UDP relay for the given listener and backend target.
// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions.
// DialTimeout controls how long to wait for backend connections; use 0 for default.
// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL.
func New(parentCtx context.Context, cfg RelayConfig) *Relay {
	ctx, cancel := context.WithCancel(parentCtx)

	r := &Relay{
		logger:      cfg.Logger,
		listener:    cfg.Listener,
		target:      cfg.Target,
		domain:      cfg.Domain,
		accountID:   cfg.AccountID,
		serviceID:   cfg.ServiceID,
		accessLog:   cfg.AccessLog,
		dialFunc:    cfg.DialFunc,
		dialTimeout: cfg.DialTimeout,
		sessionTTL:  cfg.SessionTTL,
		maxSessions: cfg.MaxSessions,
		sessions:    make(map[clientAddr]*session),
		sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst),
		ctx:         ctx,
		cancel:      cancel,
	}
	r.bufPool = sync.Pool{
		New: func() any {
			b := make([]byte, maxPacketSize)
			return &b
		},
	}

	// Clamp unset or invalid knobs to the package defaults.
	if r.maxSessions <= 0 {
		r.maxSessions = DefaultMaxSessions
	}
	if r.dialTimeout <= 0 {
		r.dialTimeout = defaultDialTimeout
	}
	if r.sessionTTL <= 0 {
		r.sessionTTL = DefaultSessionTTL
	}
	return r
}
// ServiceID returns the service ID associated with this relay.
func (r *Relay) ServiceID() types.ServiceID {
	return r.serviceID
}

// SetObserver sets the session lifecycle observer. Must be called before
// Serve: the field is read without synchronization by the relay loops.
func (r *Relay) SetObserver(obs SessionObserver) {
	r.observer = obs
}
// Serve starts the relay loop. It blocks until the context is canceled
// or the listener is closed.
// Each received datagram is forwarded to the backend connection of the
// sender's session; response traffic is handled by the per-session
// goroutine started in getOrCreateSession.
func (r *Relay) Serve() {
	go r.cleanupLoop()
	for {
		// Packet buffers are pooled to avoid a maxPacketSize allocation
		// per datagram; every path below must return bufp to the pool.
		bufp := r.bufPool.Get().(*[]byte)
		buf := *bufp
		n, addr, err := r.listener.ReadFrom(buf)
		if err != nil {
			r.bufPool.Put(bufp)
			// Canceled context or closed listener means shutdown.
			if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			r.logger.Debugf("UDP read: %v", err)
			continue
		}
		data := buf[:n]
		sess, err := r.getOrCreateSession(addr)
		if err != nil {
			r.bufPool.Put(bufp)
			r.logger.Debugf("create UDP session for %s: %v", addr, err)
			continue
		}
		sess.updateLastSeen()
		nw, err := sess.backend.Write(data)
		if err != nil {
			r.bufPool.Put(bufp)
			if !netutil.IsExpectedError(err) {
				r.logger.Debugf("UDP write to backend for %s: %v", addr, err)
			}
			// A failed backend write invalidates the session.
			r.removeSession(sess)
			continue
		}
		sess.bytesIn.Add(int64(nw))
		if r.observer != nil {
			r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
		}
		r.bufPool.Put(bufp)
	}
}
// getOrCreateSession returns an existing session or creates a new one.
//
// Concurrency scheme: a nil entry in r.sessions marks "dial in progress"
// for that key, so only one dial per client address runs at a time while
// the map lock is released during the (potentially slow) backend dial.
func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
	key := clientAddr(addr.String())
	// Fast path: read-locked lookup of an established session.
	r.mu.RLock()
	sess, ok := r.sessions[key]
	r.mu.RUnlock()
	if ok && sess != nil {
		return sess, nil
	}
	// Check before taking the write lock: if the relay is shutting down,
	// don't create new sessions. This prevents orphaned goroutines when
	// Serve() processes a packet that was already read before Close().
	if r.ctx.Err() != nil {
		return nil, r.ctx.Err()
	}
	r.mu.Lock()
	// Double-check under the write lock: another goroutine may have
	// completed (or begun) session creation for this key meanwhile.
	if sess, ok = r.sessions[key]; ok && sess != nil {
		r.mu.Unlock()
		return sess, nil
	}
	if ok {
		// Another goroutine is dialing for this key, skip.
		r.mu.Unlock()
		return nil, fmt.Errorf("session dial in progress for %s", key)
	}
	if len(r.sessions) >= r.maxSessions {
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionRejected(r.accountID)
		}
		return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
	}
	if !r.sessLimiter.Allow() {
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionRejected(r.accountID)
		}
		return nil, fmt.Errorf("session creation rate limited")
	}
	// Reserve the slot with a nil session so concurrent callers for the same
	// key see it exists and wait. Release the lock before dialing.
	r.sessions[key] = nil
	r.mu.Unlock()
	dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout)
	backend, err := r.dialFunc(dialCtx, "udp", r.target)
	dialCancel()
	if err != nil {
		// Drop the reservation so a later packet can retry the dial.
		r.mu.Lock()
		delete(r.sessions, key)
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionDialError(r.accountID)
		}
		return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
	}
	sessCtx, sessCancel := context.WithCancel(r.ctx)
	sess = &session{
		backend:   backend,
		addr:      addr,
		createdAt: time.Now(),
		cancel:    sessCancel,
	}
	sess.updateLastSeen()
	r.mu.Lock()
	r.sessions[key] = sess
	r.mu.Unlock()
	if r.observer != nil {
		r.observer.UDPSessionStarted(r.accountID)
	}
	// Response traffic is pumped by a dedicated goroutine, tracked by
	// sessWg so Close can wait for it.
	r.sessWg.Go(func() {
		r.relayBackendToClient(sessCtx, sess)
	})
	r.logger.Debugf("UDP session created for %s", addr)
	return sess, nil
}
// relayBackendToClient reads packets from the backend and writes them
// back to the client through the public-facing listener.
// It owns the session teardown: removeSession runs when the loop exits
// for any reason (context canceled, backend error, client write error).
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
	bufp := r.bufPool.Get().(*[]byte)
	defer r.bufPool.Put(bufp)
	defer r.removeSession(sess)
	for ctx.Err() == nil {
		data, ok := r.readBackendPacket(sess, *bufp)
		if !ok {
			return
		}
		// nil data with ok=true signals an idle-timeout poll: loop again.
		if data == nil {
			continue
		}
		sess.updateLastSeen()
		nw, err := r.listener.WriteTo(data, sess.addr)
		if err != nil {
			if !netutil.IsExpectedError(err) {
				r.logger.Debugf("UDP write to client %s: %v", sess.addr, err)
			}
			return
		}
		sess.bytesOut.Add(int64(nw))
		if r.observer != nil {
			r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
		}
	}
}
// readBackendPacket reads one packet from the backend with an idle deadline.
// Returns (data, true) on success, (nil, true) on idle timeout that should
// retry, or (nil, false) when the session should be torn down.
func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) {
	// The read deadline doubles as the idle-poll interval.
	if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil {
		r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err)
		return nil, false
	}
	n, err := sess.backend.Read(buf)
	if err != nil {
		if netutil.IsTimeout(err) {
			// Client traffic also refreshes lastSeen (see Serve), so a read
			// timeout only ends the session once both directions are idle.
			if sess.idleDuration() > r.sessionTTL {
				return nil, false
			}
			return nil, true
		}
		if !netutil.IsExpectedError(err) {
			r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err)
		}
		return nil, false
	}
	return buf[:n], true
}
// cleanupLoop periodically reaps idle sessions until the relay context ends.
func (r *Relay) cleanupLoop() {
	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			r.cleanupIdleSessions()
		case <-r.ctx.Done():
			return
		}
	}
}
// cleanupIdleSessions closes sessions that have been idle for too long.
// Observer callbacks and access-log writes run after the lock is released
// to keep the critical section short.
func (r *Relay) cleanupIdleSessions() {
	var expired []*session
	r.mu.Lock()
	for key, sess := range r.sessions {
		// nil entries are in-progress dial reservations; skip them.
		if sess == nil {
			continue
		}
		idle := sess.idleDuration()
		if idle > r.sessionTTL {
			r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)",
				sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load())
			delete(r.sessions, key)
			sess.cancel()
			if err := sess.backend.Close(); err != nil {
				r.logger.Debugf("close idle session %s backend: %v", sess.addr, err)
			}
			expired = append(expired, sess)
		}
	}
	r.mu.Unlock()
	for _, sess := range expired {
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
}
// removeSession removes a session from the map if it still matches the
// given pointer. This is safe to call concurrently with cleanupIdleSessions
// because the identity check prevents double-close when both paths race.
func (r *Relay) removeSession(sess *session) {
	r.mu.Lock()
	key := clientAddr(sess.addr.String())
	// Only tear down if the map still holds this exact session; a racing
	// path may have already removed (or replaced) it.
	removed := r.sessions[key] == sess
	if removed {
		delete(r.sessions, key)
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
	}
	r.mu.Unlock()
	// Logging and observer notification happen outside the lock.
	if removed {
		r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
}
// logSessionEnd sends an access log entry for a completed UDP session.
func (r *Relay) logSessionEnd(sess *session) {
	al := r.accessLog
	if al == nil {
		return
	}

	// Best-effort source address; zero netip.Addr when unparsable.
	var sourceIP netip.Addr
	if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil {
		sourceIP = ap.Addr().Unmap()
	}

	// Duration spans from creation to the last observed activity.
	lastSeen := time.Unix(0, sess.lastSeen.Load())
	al.LogL4(accesslog.L4Entry{
		AccountID:     r.accountID,
		ServiceID:     r.serviceID,
		Protocol:      accesslog.ProtocolUDP,
		Host:          r.domain,
		SourceIP:      sourceIP,
		DurationMs:    lastSeen.Sub(sess.createdAt).Milliseconds(),
		BytesUpload:   sess.bytesIn.Load(),
		BytesDownload: sess.bytesOut.Load(),
	})
}
// Close stops the relay, waits for all session goroutines to exit,
// and cleans up remaining sessions.
func (r *Relay) Close() {
	// Cancel first so Serve and all per-session goroutines begin exiting.
	r.cancel()
	if err := r.listener.Close(); err != nil {
		r.logger.Debugf("close UDP listener: %v", err)
	}
	var closedSessions []*session
	r.mu.Lock()
	for key, sess := range r.sessions {
		// nil entries are dial reservations with no resources to release.
		if sess == nil {
			delete(r.sessions, key)
			continue
		}
		r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
		delete(r.sessions, key)
		closedSessions = append(closedSessions, sess)
	}
	r.mu.Unlock()
	// Observer/access-log notifications happen outside the lock.
	for _, sess := range closedSessions {
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
	r.sessWg.Wait()
}

View File

@@ -0,0 +1,493 @@
package udp
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// TestRelay_BasicPacketExchange round-trips one datagram through the relay
// against an echoing UDP backend.
func TestRelay_BasicPacketExchange(t *testing.T) {
	// Set up a UDP backend that echoes packets.
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()
	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()
	// Set up the relay's public-facing listener.
	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	logger := log.NewEntry(log.StandardLogger())
	backendAddr := backend.LocalAddr().String()
	// Plain net.Dial stands in for the WireGuard tunnel dialer.
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backendAddr, DialFunc: dialFunc})
	go relay.Serve()
	defer relay.Close()
	// Create a client and send a packet to the relay.
	client, err := net.Dial("udp", listener.LocalAddr().String())
	require.NoError(t, err)
	defer client.Close()
	testData := []byte("hello UDP relay")
	_, err = client.Write(testData)
	require.NoError(t, err)
	// Read the echoed response.
	if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatal(err)
	}
	buf := make([]byte, 1024)
	n, err := client.Read(buf)
	require.NoError(t, err)
	assert.Equal(t, testData, buf[:n], "should receive echoed packet")
}
// TestRelay_MultipleClients verifies that distinct client addresses get
// distinct sessions and that each receives only its own echo.
func TestRelay_MultipleClients(t *testing.T) {
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()
	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()
	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	logger := log.NewEntry(log.StandardLogger())
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
	go relay.Serve()
	defer relay.Close()
	// Two clients, each should get their own session.
	for i, msg := range []string{"client-1", "client-2"} {
		client, err := net.Dial("udp", listener.LocalAddr().String())
		require.NoError(t, err, "client %d", i)
		defer client.Close()
		_, err = client.Write([]byte(msg))
		require.NoError(t, err)
		if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
			t.Fatal(err)
		}
		buf := make([]byte, 1024)
		n, err := client.Read(buf)
		require.NoError(t, err, "client %d read", i)
		assert.Equal(t, msg, string(buf[:n]), "client %d should get own echo", i)
	}
	// Verify two sessions were created.
	relay.mu.RLock()
	sessionCount := len(relay.sessions)
	relay.mu.RUnlock()
	assert.Equal(t, 2, sessionCount, "should have two sessions")
}
// TestRelay_Close verifies that Close unblocks a running Serve loop.
func TestRelay_Close(t *testing.T) {
	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	entry := log.NewEntry(log.StandardLogger())
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	r := New(ctx, RelayConfig{Logger: entry, Listener: listener, Target: "127.0.0.1:9999", DialFunc: dial})
	served := make(chan struct{})
	go func() {
		defer close(served)
		r.Serve()
	}()
	r.Close()
	// Serve must return promptly once the relay is closed.
	select {
	case <-served:
	case <-time.After(5 * time.Second):
		t.Fatal("Serve did not return after Close")
	}
}
// TestRelay_SessionCleanup verifies that sessions idle for longer than the
// TTL are removed by cleanupIdleSessions.
func TestRelay_SessionCleanup(t *testing.T) {
	// Echo backend so the client round-trip succeeds.
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		pkt := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(pkt)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(pkt[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	entry := log.NewEntry(log.StandardLogger())
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	r := New(ctx, RelayConfig{Logger: entry, Listener: front, Target: echo.LocalAddr().String(), DialFunc: dial})
	go r.Serve()
	defer r.Close()
	// Create a session by completing one echo round-trip.
	conn, err := net.Dial("udp", front.LocalAddr().String())
	require.NoError(t, err)
	_, err = conn.Write([]byte("hello"))
	require.NoError(t, err)
	require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
	reply := make([]byte, 1024)
	_, err = conn.Read(reply)
	require.NoError(t, err)
	conn.Close()
	// The session must exist after the round-trip.
	r.mu.RLock()
	assert.Equal(t, 1, len(r.sessions))
	r.mu.RUnlock()
	// Backdate lastSeen so the session looks idle well past the TTL.
	r.mu.Lock()
	for _, s := range r.sessions {
		s.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano())
	}
	r.mu.Unlock()
	// Run the cleanup pass directly instead of waiting for its ticker.
	r.cleanupIdleSessions()
	r.mu.RLock()
	assert.Equal(t, 0, len(r.sessions), "idle sessions should be cleaned up")
	r.mu.RUnlock()
}
// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new
// one on the same port works cleanly (simulates a port-mapping modify cycle).
func TestRelay_CloseAndRecreate(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		pkt := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(pkt)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(pkt[:n], from)
		}
	}()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	entry := log.NewEntry(log.StandardLogger())
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	// roundTrip sends msg through the relay at relayAddr and returns the echo.
	roundTrip := func(relayAddr, msg string) string {
		conn, err := net.Dial("udp", relayAddr)
		require.NoError(t, err)
		defer conn.Close()
		_, err = conn.Write([]byte(msg))
		require.NoError(t, err)
		require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
		reply := make([]byte, 1024)
		n, err := conn.Read(reply)
		require.NoError(t, err)
		return string(reply[:n])
	}
	// First relay on an ephemeral port.
	ln1, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	relay1 := New(ctx, RelayConfig{Logger: entry, Listener: ln1, Target: echo.LocalAddr().String(), DialFunc: dial})
	go relay1.Serve()
	assert.Equal(t, "relay1", roundTrip(ln1.LocalAddr().String(), "relay1"))
	// Close the first relay, then bind a second one to the exact same port.
	relay1.Close()
	port := ln1.LocalAddr().(*net.UDPAddr).Port
	ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port))
	require.NoError(t, err)
	relay2 := New(ctx, RelayConfig{Logger: entry, Listener: ln2, Target: echo.LocalAddr().String(), DialFunc: dial})
	go relay2.Serve()
	defer relay2.Close()
	assert.Equal(t, "relay2", roundTrip(ln2.LocalAddr().String(), "relay2"), "second relay should work on same port")
}
// TestRelay_SessionLimit verifies that the relay stops creating sessions once
// MaxSessions is reached: a third client's packet is silently dropped.
func TestRelay_SessionLimit(t *testing.T) {
	// Echo backend so established sessions get responses.
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()
	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()
	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	logger := log.NewEntry(log.StandardLogger())
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	// Create a relay with a max of 2 sessions.
	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc, MaxSessions: 2})
	go relay.Serve()
	defer relay.Close()
	// Create 2 clients to fill up the session limit. Close each client at the
	// end of its iteration — sessions survive the client socket closing, so
	// the limit stays exhausted; deferring inside the loop would hold every
	// socket open until the whole test returns.
	for i := range 2 {
		client, err := net.Dial("udp", listener.LocalAddr().String())
		require.NoError(t, err, "client %d", i)
		_, err = client.Write([]byte("hello"))
		require.NoError(t, err)
		require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
		buf := make([]byte, 1024)
		_, err = client.Read(buf)
		client.Close()
		require.NoError(t, err, "client %d should get response", i)
	}
	relay.mu.RLock()
	assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions")
	relay.mu.RUnlock()
	// Third client should get its packet dropped (session creation fails).
	client3, err := net.Dial("udp", listener.LocalAddr().String())
	require.NoError(t, err)
	defer client3.Close()
	_, err = client3.Write([]byte("should be dropped"))
	require.NoError(t, err)
	require.NoError(t, client3.SetReadDeadline(time.Now().Add(500*time.Millisecond)))
	buf := make([]byte, 1024)
	_, err = client3.Read(buf)
	assert.Error(t, err, "third client should time out because session was rejected")
	relay.mu.RLock()
	assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit")
	relay.mu.RUnlock()
}
// testObserver records UDP session lifecycle events for test assertions.
// It implements the relay's observer interface; all counters are guarded
// by mu because observer callbacks may fire from relay goroutines.
type testObserver struct {
	mu       sync.Mutex // guards all counters below
	started  int        // number of UDPSessionStarted calls
	ended    int        // number of UDPSessionEnded calls
	rejected int        // number of UDPSessionRejected calls
	dialErr  int        // number of UDPSessionDialError calls
	packets  int        // number of UDPPacketRelayed calls
	bytes    int        // sum of byte counts reported via UDPPacketRelayed
}
// UDPSessionStarted counts a session-start event.
func (o *testObserver) UDPSessionStarted(types.AccountID) {
	o.mu.Lock()
	defer o.mu.Unlock()
	o.started++
}
// UDPSessionEnded counts a session-end event.
func (o *testObserver) UDPSessionEnded(types.AccountID) {
	o.mu.Lock()
	defer o.mu.Unlock()
	o.ended++
}
// UDPSessionDialError counts a failed backend dial for a session.
func (o *testObserver) UDPSessionDialError(types.AccountID) {
	o.mu.Lock()
	defer o.mu.Unlock()
	o.dialErr++
}
// UDPSessionRejected counts a rejected session creation (limit or rate).
func (o *testObserver) UDPSessionRejected(types.AccountID) {
	o.mu.Lock()
	defer o.mu.Unlock()
	o.rejected++
}
// UDPPacketRelayed counts one relayed packet and accumulates its byte size.
func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) {
	o.mu.Lock()
	defer o.mu.Unlock()
	o.packets++
	o.bytes += b
}
// TestRelay_CloseFiresObserverEnded verifies that Close emits an
// UDPSessionEnded event for every session still alive at shutdown.
func TestRelay_CloseFiresObserverEnded(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		pkt := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(pkt)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(pkt[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	entry := log.NewEntry(log.StandardLogger())
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	obs := &testObserver{}
	r := New(ctx, RelayConfig{Logger: entry, Listener: front, Target: echo.LocalAddr().String(), AccountID: "test-acct", DialFunc: dial})
	r.SetObserver(obs)
	go r.Serve()
	// Establish two sessions, each via one echo round-trip.
	for i := 0; i < 2; i++ {
		conn, err := net.Dial("udp", front.LocalAddr().String())
		require.NoError(t, err, "client %d", i)
		_, err = conn.Write([]byte("hello"))
		require.NoError(t, err)
		require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
		reply := make([]byte, 1024)
		_, err = conn.Read(reply)
		conn.Close()
		require.NoError(t, err)
	}
	obs.mu.Lock()
	assert.Equal(t, 2, obs.started, "should have 2 started events")
	obs.mu.Unlock()
	// Close should fire UDPSessionEnded for all remaining sessions.
	r.Close()
	obs.mu.Lock()
	assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session")
	obs.mu.Unlock()
}
// TestRelay_SessionRateLimit verifies that session creation is throttled by
// the relay's internal rate limiter even when MaxSessions is far from reached.
func TestRelay_SessionRateLimit(t *testing.T) {
	echo, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echo.Close()
	go func() {
		pkt := make([]byte, 65535)
		for {
			n, from, err := echo.ReadFrom(pkt)
			if err != nil {
				return
			}
			_, _ = echo.WriteTo(pkt[:n], from)
		}
	}()
	front, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer front.Close()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	entry := log.NewEntry(log.StandardLogger())
	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}
	obs := &testObserver{}
	// High max sessions (1000) but the relay uses a rate limiter internally
	// (default: 50/s burst 100). We exhaust the burst by creating sessions
	// rapidly, then verify that subsequent creates are rejected.
	r := New(ctx, RelayConfig{Logger: entry, Listener: front, Target: echo.LocalAddr().String(), AccountID: "test-acct", DialFunc: dial, MaxSessions: 1000})
	r.SetObserver(obs)
	go r.Serve()
	defer r.Close()
	// Exhaust the burst by calling getOrCreateSession directly with
	// synthetic addresses. This is faster than real UDP round-trips.
	for i := 0; i < sessionCreateBurst+20; i++ {
		from := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i}
		_, _ = r.getOrCreateSession(from)
	}
	obs.mu.Lock()
	rejectedCount := obs.rejected
	obs.mu.Unlock()
	assert.Greater(t, rejectedCount, 0, "some sessions should be rate-limited")
}

View File

@@ -243,6 +243,10 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string {
return nil
}
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
return nil
}
// storeBackedServiceManager reads directly from the real store.
type storeBackedServiceManager struct {
store store.Store
@@ -505,15 +509,15 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
nil,
"",
0,
mapping.GetAccountId(),
mapping.GetId(),
proxytypes.AccountID(mapping.GetAccountId()),
proxytypes.ServiceID(mapping.GetId()),
)
require.NoError(t, err)
// Apply to real proxy (idempotent)
proxyHandler.AddMapping(proxy.Mapping{
Host: mapping.GetDomain(),
ID: mapping.GetId(),
ID: proxytypes.ServiceID(mapping.GetId()),
AccountID: proxytypes.AccountID(mapping.GetAccountId()),
})
}

File diff suppressed because it is too large Load Diff