mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:24:18 -04:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -34,30 +35,32 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
logLevel string
|
||||
debugLogs bool
|
||||
mgmtAddr string
|
||||
addr string
|
||||
proxyDomain string
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeEABKID string
|
||||
acmeEABHMACKey string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
forwardedProto string
|
||||
trustedProxies string
|
||||
certFile string
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wildcardCertDir string
|
||||
wgPort int
|
||||
proxyProtocol bool
|
||||
preSharedKey string
|
||||
logLevel string
|
||||
debugLogs bool
|
||||
mgmtAddr string
|
||||
addr string
|
||||
proxyDomain string
|
||||
defaultDialTimeout time.Duration
|
||||
certDir string
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeEABKID string
|
||||
acmeEABHMACKey string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
healthAddr string
|
||||
forwardedProto string
|
||||
trustedProxies string
|
||||
certFile string
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
wildcardCertDir string
|
||||
wgPort uint16
|
||||
proxyProtocol bool
|
||||
preSharedKey string
|
||||
supportsCustomPorts bool
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -92,9 +95,11 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||
rootCmd.Flags().StringVar(&wildcardCertDir, "wildcard-cert-dir", envStringOrDefault("NB_PROXY_WILDCARD_CERT_DIR", ""), "Directory containing wildcard certificate pairs (<name>.crt/<name>.key). Wildcard patterns are extracted from SANs automatically")
|
||||
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||
rootCmd.Flags().Uint16Var(&wgPort, "wg-port", envUint16OrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
|
||||
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
|
||||
rootCmd.Flags().DurationVar(&defaultDialTimeout, "default-dial-timeout", envDurationOrDefault("NB_PROXY_DEFAULT_DIAL_TIMEOUT", 0), "Default backend dial timeout when no per-service timeout is set (e.g. 30s)")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -171,6 +176,8 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
WireguardPort: wgPort,
|
||||
ProxyProtocol: proxyProtocol,
|
||||
PreSharedKey: preSharedKey,
|
||||
SupportsCustomPorts: supportsCustomPorts,
|
||||
DefaultDialTimeout: defaultDialTimeout,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
@@ -203,12 +210,24 @@ func envStringOrDefault(key string, def string) string {
|
||||
return v
|
||||
}
|
||||
|
||||
func envIntOrDefault(key string, def int) int {
|
||||
func envUint16OrDefault(key string, def uint16) uint16 {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
parsed, err := strconv.Atoi(v)
|
||||
parsed, err := strconv.ParseUint(v, 10, 16)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return uint16(parsed)
|
||||
}
|
||||
|
||||
func envDurationOrDefault(key string, def time.Duration) time.Duration {
|
||||
v, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return def
|
||||
}
|
||||
parsed, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
|
||||
@@ -38,11 +38,18 @@ func (m *mockMappingStream) Context() context.Context { return context.Backgroun
|
||||
func (m *mockMappingStream) SendMsg(any) error { return nil }
|
||||
func (m *mockMappingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func closedChan() chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
|
||||
checker := health.NewChecker(nil, nil)
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
routerReady: closedChan(),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
@@ -62,6 +69,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
routerReady: closedChan(),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
@@ -78,7 +86,8 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
|
||||
|
||||
func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -19,6 +21,7 @@ const (
|
||||
bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB
|
||||
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
|
||||
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
|
||||
logSendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type domainUsage struct {
|
||||
@@ -79,22 +82,63 @@ func (l *Logger) Close() {
|
||||
|
||||
type logEntry struct {
|
||||
ID string
|
||||
AccountID string
|
||||
ServiceId string
|
||||
AccountID types.AccountID
|
||||
ServiceId types.ServiceID
|
||||
Host string
|
||||
Path string
|
||||
DurationMs int64
|
||||
Method string
|
||||
ResponseCode int32
|
||||
SourceIp string
|
||||
SourceIP netip.Addr
|
||||
AuthMechanism string
|
||||
UserId string
|
||||
AuthSuccess bool
|
||||
BytesUpload int64
|
||||
BytesDownload int64
|
||||
Protocol Protocol
|
||||
}
|
||||
|
||||
func (l *Logger) log(ctx context.Context, entry logEntry) {
|
||||
// Protocol identifies the transport protocol of an access log entry.
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
ProtocolHTTP Protocol = "http"
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
ProtocolUDP Protocol = "udp"
|
||||
ProtocolTLS Protocol = "tls"
|
||||
)
|
||||
|
||||
// L4Entry holds the data for a layer-4 (TCP/UDP) access log entry.
|
||||
type L4Entry struct {
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
Protocol Protocol
|
||||
Host string // SNI hostname or listen address
|
||||
SourceIP netip.Addr
|
||||
DurationMs int64
|
||||
BytesUpload int64
|
||||
BytesDownload int64
|
||||
}
|
||||
|
||||
// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP).
|
||||
// The call is non-blocking: the gRPC send happens in a background goroutine.
|
||||
func (l *Logger) LogL4(entry L4Entry) {
|
||||
le := logEntry{
|
||||
ID: xid.New().String(),
|
||||
AccountID: entry.AccountID,
|
||||
ServiceId: entry.ServiceID,
|
||||
Protocol: entry.Protocol,
|
||||
Host: entry.Host,
|
||||
SourceIP: entry.SourceIP,
|
||||
DurationMs: entry.DurationMs,
|
||||
BytesUpload: entry.BytesUpload,
|
||||
BytesDownload: entry.BytesDownload,
|
||||
}
|
||||
l.log(le)
|
||||
l.trackUsage(entry.Host, entry.BytesUpload+entry.BytesDownload)
|
||||
}
|
||||
|
||||
func (l *Logger) log(entry logEntry) {
|
||||
// Fire off the log request in a separate routine.
|
||||
// This increases the possibility of losing a log message
|
||||
// (although it should still get logged in the event of an error),
|
||||
@@ -105,31 +149,37 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
|
||||
// allow for resolving that on the server.
|
||||
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
|
||||
go func() {
|
||||
logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
logCtx, cancel := context.WithTimeout(context.Background(), logSendTimeout)
|
||||
defer cancel()
|
||||
if entry.AuthMechanism != auth.MethodOIDC.String() {
|
||||
entry.UserId = ""
|
||||
}
|
||||
|
||||
var sourceIP string
|
||||
if entry.SourceIP.IsValid() {
|
||||
sourceIP = entry.SourceIP.String()
|
||||
}
|
||||
|
||||
if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{
|
||||
Log: &proto.AccessLog{
|
||||
LogId: entry.ID,
|
||||
AccountId: entry.AccountID,
|
||||
AccountId: string(entry.AccountID),
|
||||
Timestamp: now,
|
||||
ServiceId: entry.ServiceId,
|
||||
ServiceId: string(entry.ServiceId),
|
||||
Host: entry.Host,
|
||||
Path: entry.Path,
|
||||
DurationMs: entry.DurationMs,
|
||||
Method: entry.Method,
|
||||
ResponseCode: entry.ResponseCode,
|
||||
SourceIp: entry.SourceIp,
|
||||
SourceIp: sourceIP,
|
||||
AuthMechanism: entry.AuthMechanism,
|
||||
UserId: entry.UserId,
|
||||
AuthSuccess: entry.AuthSuccess,
|
||||
BytesUpload: entry.BytesUpload,
|
||||
BytesDownload: entry.BytesDownload,
|
||||
Protocol: string(entry.Protocol),
|
||||
},
|
||||
}); err != nil {
|
||||
// If it fails to send on the gRPC connection, then at least log it to the error log.
|
||||
l.logger.WithFields(log.Fields{
|
||||
"service_id": entry.ServiceId,
|
||||
"host": entry.Host,
|
||||
@@ -137,7 +187,7 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
|
||||
"duration": entry.DurationMs,
|
||||
"method": entry.Method,
|
||||
"response_code": entry.ResponseCode,
|
||||
"source_ip": entry.SourceIp,
|
||||
"source_ip": sourceIP,
|
||||
"auth_mechanism": entry.AuthMechanism,
|
||||
"user_id": entry.UserId,
|
||||
"auth_success": entry.AuthSuccess,
|
||||
|
||||
@@ -67,23 +67,24 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
entry := logEntry{
|
||||
ID: requestID,
|
||||
ServiceId: capturedData.GetServiceId(),
|
||||
AccountID: string(capturedData.GetAccountId()),
|
||||
AccountID: capturedData.GetAccountId(),
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Method: r.Method,
|
||||
ResponseCode: int32(sw.status),
|
||||
SourceIp: sourceIp,
|
||||
SourceIP: sourceIp,
|
||||
AuthMechanism: capturedData.GetAuthMethod(),
|
||||
UserId: capturedData.GetUserID(),
|
||||
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
|
||||
BytesUpload: bytesUpload,
|
||||
BytesDownload: bytesDownload,
|
||||
Protocol: ProtocolHTTP,
|
||||
}
|
||||
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
|
||||
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
|
||||
|
||||
l.log(r.Context(), entry)
|
||||
l.log(entry)
|
||||
|
||||
// Track usage for cost monitoring (upload + download) by domain
|
||||
l.trackUsage(host, bytesUpload+bytesDownload)
|
||||
|
||||
@@ -11,6 +11,6 @@ import (
|
||||
// proxy configuration. When trustedProxies is non-empty and the direct
|
||||
// connection is from a trusted source, it walks X-Forwarded-For right-to-left
|
||||
// skipping trusted IPs. Otherwise it returns RemoteAddr directly.
|
||||
func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string {
|
||||
func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) netip.Addr {
|
||||
return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ import (
|
||||
var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2}
|
||||
|
||||
type certificateNotifier interface {
|
||||
NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error
|
||||
NotifyCertificateIssued(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, domain string) error
|
||||
}
|
||||
|
||||
type domainState int
|
||||
@@ -42,8 +43,8 @@ const (
|
||||
)
|
||||
|
||||
type domainInfo struct {
|
||||
accountID string
|
||||
serviceID string
|
||||
accountID types.AccountID
|
||||
serviceID types.ServiceID
|
||||
state domainState
|
||||
err string
|
||||
}
|
||||
@@ -301,7 +302,7 @@ func (mgr *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate
|
||||
// When AddDomain returns true the caller is responsible for sending any
|
||||
// certificate-ready notifications after the surrounding operation (e.g.
|
||||
// mapping update) has committed successfully.
|
||||
func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) (wildcardHit bool) {
|
||||
func (mgr *Manager) AddDomain(d domain.Domain, accountID types.AccountID, serviceID types.ServiceID) (wildcardHit bool) {
|
||||
name := d.PunycodeString()
|
||||
if e := mgr.findWildcardEntry(name); e != nil {
|
||||
mgr.mu.Lock()
|
||||
|
||||
@@ -17,12 +17,14 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
func TestHostPolicy(t *testing.T) {
|
||||
mgr, err := NewManager(ManagerConfig{CertDir: t.TempDir(), ACMEURL: "https://acme.example.com/directory"}, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.AddDomain("example.com", "acc1", "rp1")
|
||||
mgr.AddDomain("example.com", types.AccountID("acc1"), types.ServiceID("rp1"))
|
||||
|
||||
// Wait for the background prefetch goroutine to finish so the temp dir
|
||||
// can be cleaned up without a race.
|
||||
@@ -92,8 +94,8 @@ func TestDomainStates(t *testing.T) {
|
||||
|
||||
// AddDomain starts as pending, then the prefetch goroutine will fail
|
||||
// (no real ACME server) and transition to failed.
|
||||
mgr.AddDomain("a.example.com", "acc1", "rp1")
|
||||
mgr.AddDomain("b.example.com", "acc1", "rp1")
|
||||
mgr.AddDomain("a.example.com", types.AccountID("acc1"), types.ServiceID("rp1"))
|
||||
mgr.AddDomain("b.example.com", types.AccountID("acc1"), types.ServiceID("rp1"))
|
||||
|
||||
assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered")
|
||||
|
||||
@@ -209,12 +211,12 @@ func TestWildcardAddDomainSkipsACME(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a wildcard-matching domain — should be immediately ready.
|
||||
mgr.AddDomain("foo.example.com", "acc1", "svc1")
|
||||
mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1"))
|
||||
assert.Equal(t, 0, mgr.PendingCerts(), "wildcard domain should not be pending")
|
||||
assert.Equal(t, []string{"foo.example.com"}, mgr.ReadyDomains())
|
||||
|
||||
// Add a non-wildcard domain — should go through ACME (pending then failed).
|
||||
mgr.AddDomain("other.net", "acc2", "svc2")
|
||||
mgr.AddDomain("other.net", types.AccountID("acc2"), types.ServiceID("svc2"))
|
||||
assert.Equal(t, 2, mgr.TotalDomains())
|
||||
|
||||
// Wait for the ACME prefetch to fail.
|
||||
@@ -234,7 +236,7 @@ func TestWildcardGetCertificate(t *testing.T) {
|
||||
mgr, err := NewManager(ManagerConfig{CertDir: acmeDir, ACMEURL: "https://acme.example.com/directory", WildcardDir: wcDir}, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr.AddDomain("foo.example.com", "acc1", "svc1")
|
||||
mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1"))
|
||||
|
||||
// GetCertificate for a wildcard-matching domain should return the static cert.
|
||||
cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
||||
@@ -255,8 +257,8 @@ func TestMultipleWildcards(t *testing.T) {
|
||||
assert.ElementsMatch(t, []string{"*.example.com", "*.other.org"}, mgr.WildcardPatterns())
|
||||
|
||||
// Both wildcards should resolve.
|
||||
mgr.AddDomain("foo.example.com", "acc1", "svc1")
|
||||
mgr.AddDomain("bar.other.org", "acc2", "svc2")
|
||||
mgr.AddDomain("foo.example.com", types.AccountID("acc1"), types.ServiceID("svc1"))
|
||||
mgr.AddDomain("bar.other.org", types.AccountID("acc2"), types.ServiceID("svc2"))
|
||||
|
||||
assert.Equal(t, 0, mgr.PendingCerts())
|
||||
assert.ElementsMatch(t, []string{"foo.example.com", "bar.other.org"}, mgr.ReadyDomains())
|
||||
@@ -271,7 +273,7 @@ func TestMultipleWildcards(t *testing.T) {
|
||||
assert.Contains(t, cert2.Leaf.DNSNames, "*.other.org")
|
||||
|
||||
// Non-matching domain falls through to ACME.
|
||||
mgr.AddDomain("custom.net", "acc3", "svc3")
|
||||
mgr.AddDomain("custom.net", types.AccountID("acc3"), types.ServiceID("svc3"))
|
||||
assert.Eventually(t, func() bool {
|
||||
return mgr.PendingCerts() == 0
|
||||
}, 30*time.Second, 100*time.Millisecond)
|
||||
|
||||
@@ -44,8 +44,8 @@ type DomainConfig struct {
|
||||
Schemes []Scheme
|
||||
SessionPublicKey ed25519.PublicKey
|
||||
SessionExpiration time.Duration
|
||||
AccountID string
|
||||
ServiceID string
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
}
|
||||
|
||||
type validationResult struct {
|
||||
@@ -124,7 +124,7 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
|
||||
func setCapturedIDs(r *http.Request, config DomainConfig) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetAccountId(config.AccountID)
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
}
|
||||
@@ -275,7 +275,7 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error {
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -17,14 +18,14 @@ type urlGenerator interface {
|
||||
}
|
||||
|
||||
type OIDC struct {
|
||||
id string
|
||||
accountId string
|
||||
id types.ServiceID
|
||||
accountId types.AccountID
|
||||
forwardedProto string
|
||||
client urlGenerator
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC {
|
||||
func NewOIDC(client urlGenerator, id types.ServiceID, accountId types.AccountID, forwardedProto string) OIDC {
|
||||
return OIDC{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
@@ -53,8 +54,8 @@ func (o OIDC) Authenticate(r *http.Request) (string, string, error) {
|
||||
}
|
||||
|
||||
res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{
|
||||
Id: o.id,
|
||||
AccountId: o.accountId,
|
||||
Id: string(o.id),
|
||||
AccountId: string(o.accountId),
|
||||
RedirectUrl: redirectURL.String(),
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -5,17 +5,19 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const passwordFormId = "password"
|
||||
|
||||
type Password struct {
|
||||
id, accountId string
|
||||
client authenticator
|
||||
id types.ServiceID
|
||||
accountId types.AccountID
|
||||
client authenticator
|
||||
}
|
||||
|
||||
func NewPassword(client authenticator, id, accountId string) Password {
|
||||
func NewPassword(client authenticator, id types.ServiceID, accountId types.AccountID) Password {
|
||||
return Password{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
@@ -41,8 +43,8 @@ func (p Password) Authenticate(r *http.Request) (string, string, error) {
|
||||
}
|
||||
|
||||
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: p.id,
|
||||
AccountId: p.accountId,
|
||||
Id: string(p.id),
|
||||
AccountId: string(p.accountId),
|
||||
Request: &proto.AuthenticateRequest_Password{
|
||||
Password: &proto.PasswordRequest{
|
||||
Password: password,
|
||||
|
||||
@@ -5,17 +5,19 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const pinFormId = "pin"
|
||||
|
||||
type Pin struct {
|
||||
id, accountId string
|
||||
client authenticator
|
||||
id types.ServiceID
|
||||
accountId types.AccountID
|
||||
client authenticator
|
||||
}
|
||||
|
||||
func NewPin(client authenticator, id, accountId string) Pin {
|
||||
func NewPin(client authenticator, id types.ServiceID, accountId types.AccountID) Pin {
|
||||
return Pin{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
@@ -41,8 +43,8 @@ func (p Pin) Authenticate(r *http.Request) (string, string, error) {
|
||||
}
|
||||
|
||||
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: p.id,
|
||||
AccountId: p.accountId,
|
||||
Id: string(p.id),
|
||||
AccountId: string(p.accountId),
|
||||
Request: &proto.AuthenticateRequest_Pin{
|
||||
Pin: &proto.PinRequest{
|
||||
Pin: pin,
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
tracker *HijackTracker
|
||||
host string
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.tracker.conns.Delete(c)
|
||||
c.tracker.remove(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
@@ -22,6 +23,7 @@ func (c *trackedConn) Close() error {
|
||||
type trackingWriter struct {
|
||||
http.ResponseWriter
|
||||
tracker *HijackTracker
|
||||
host string
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
@@ -33,8 +35,8 @@ func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||
w.tracker.conns.Store(tc, struct{}{})
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker, host: w.host}
|
||||
w.tracker.add(tc)
|
||||
return tc, buf, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
@@ -10,10 +9,14 @@ import (
|
||||
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||
// they must be tracked and closed explicitly during graceful shutdown.
|
||||
//
|
||||
// Connections are indexed by the request Host so they can be closed
|
||||
// per-domain when a service mapping is removed.
|
||||
//
|
||||
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||
// connections are tracked and automatically deregistered when closed.
|
||||
type HijackTracker struct {
|
||||
conns sync.Map // net.Conn → struct{}
|
||||
mu sync.Mutex
|
||||
conns map[*trackedConn]struct{}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||
@@ -21,21 +24,73 @@ type HijackTracker struct {
|
||||
// tracker when closed. This should be the outermost middleware in the chain.
|
||||
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||
next.ServeHTTP(&trackingWriter{
|
||||
ResponseWriter: w,
|
||||
tracker: t,
|
||||
host: hostOnly(r.Host),
|
||||
}, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CloseAll closes all tracked hijacked connections and returns the number
|
||||
// of connections that were closed.
|
||||
// CloseAll closes all tracked hijacked connections and returns the count.
|
||||
func (t *HijackTracker) CloseAll() int {
|
||||
var count int
|
||||
t.conns.Range(func(key, _ any) bool {
|
||||
if conn, ok := key.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
count++
|
||||
}
|
||||
t.conns.Delete(key)
|
||||
return true
|
||||
})
|
||||
return count
|
||||
t.mu.Lock()
|
||||
conns := t.conns
|
||||
t.conns = nil
|
||||
t.mu.Unlock()
|
||||
|
||||
for tc := range conns {
|
||||
_ = tc.Conn.Close()
|
||||
}
|
||||
return len(conns)
|
||||
}
|
||||
|
||||
// CloseByHost closes all tracked hijacked connections for the given host
|
||||
// and returns the number of connections closed.
|
||||
func (t *HijackTracker) CloseByHost(host string) int {
|
||||
host = hostOnly(host)
|
||||
t.mu.Lock()
|
||||
var toClose []*trackedConn
|
||||
for tc := range t.conns {
|
||||
if tc.host == host {
|
||||
toClose = append(toClose, tc)
|
||||
}
|
||||
}
|
||||
for _, tc := range toClose {
|
||||
delete(t.conns, tc)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
for _, tc := range toClose {
|
||||
_ = tc.Conn.Close()
|
||||
}
|
||||
return len(toClose)
|
||||
}
|
||||
|
||||
func (t *HijackTracker) add(tc *trackedConn) {
|
||||
t.mu.Lock()
|
||||
if t.conns == nil {
|
||||
t.conns = make(map[*trackedConn]struct{})
|
||||
}
|
||||
t.conns[tc] = struct{}{}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
func (t *HijackTracker) remove(tc *trackedConn) {
|
||||
t.mu.Lock()
|
||||
delete(t.conns, tc)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// hostOnly strips the port from a host:port string.
|
||||
func hostOnly(hostport string) string {
|
||||
for i := len(hostport) - 1; i >= 0; i-- {
|
||||
if hostport[i] == ':' {
|
||||
return hostport[:i]
|
||||
}
|
||||
if hostport[i] < '0' || hostport[i] > '9' {
|
||||
return hostport
|
||||
}
|
||||
}
|
||||
return hostport
|
||||
}
|
||||
|
||||
142
proxy/internal/conntrack/hijacked_test.go
Normal file
142
proxy/internal/conntrack/hijacked_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeHijackWriter implements http.ResponseWriter and http.Hijacker for testing.
|
||||
type fakeHijackWriter struct {
|
||||
http.ResponseWriter
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (f *fakeHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
||||
return f.conn, rw, nil
|
||||
}
|
||||
|
||||
func TestCloseByHost(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
// Simulate hijacking two connections for different hosts.
|
||||
connA1, connA2 := net.Pipe()
|
||||
defer connA2.Close()
|
||||
connB1, connB2 := net.Pipe()
|
||||
defer connB2.Close()
|
||||
|
||||
twA := &trackingWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
tracker: &tracker,
|
||||
host: "a.example.com",
|
||||
}
|
||||
twB := &trackingWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
tracker: &tracker,
|
||||
host: "b.example.com",
|
||||
}
|
||||
|
||||
// Use fakeHijackWriter to provide the Hijack method.
|
||||
twA.ResponseWriter = &fakeHijackWriter{ResponseWriter: twA.ResponseWriter, conn: connA1}
|
||||
twB.ResponseWriter = &fakeHijackWriter{ResponseWriter: twB.ResponseWriter, conn: connB1}
|
||||
|
||||
_, _, err := twA.Hijack()
|
||||
require.NoError(t, err)
|
||||
_, _, err = twB.Hijack()
|
||||
require.NoError(t, err)
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 2, len(tracker.conns), "should track 2 connections")
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Close only host A.
|
||||
n := tracker.CloseByHost("a.example.com")
|
||||
assert.Equal(t, 1, n, "should close 1 connection for host A")
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 1, len(tracker.conns), "should have 1 remaining connection")
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Verify host A's conn is actually closed.
|
||||
buf := make([]byte, 1)
|
||||
_, err = connA2.Read(buf)
|
||||
assert.Error(t, err, "host A pipe should be closed")
|
||||
|
||||
// Host B should still be alive.
|
||||
go func() { _, _ = connB1.Write([]byte("x")) }()
|
||||
|
||||
// Close all remaining.
|
||||
n = tracker.CloseAll()
|
||||
assert.Equal(t, 1, n, "should close remaining 1 connection")
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 0, len(tracker.conns), "should have 0 connections after CloseAll")
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestCloseAll(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
for range 5 {
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
tc := &trackedConn{Conn: c1, tracker: &tracker, host: "test.com"}
|
||||
tracker.add(tc)
|
||||
}
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 5, len(tracker.conns))
|
||||
tracker.mu.Unlock()
|
||||
|
||||
n := tracker.CloseAll()
|
||||
assert.Equal(t, 5, n)
|
||||
|
||||
// Double CloseAll is safe.
|
||||
n = tracker.CloseAll()
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func TestTrackedConn_AutoDeregister(t *testing.T) {
|
||||
var tracker HijackTracker
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
defer c2.Close()
|
||||
|
||||
tc := &trackedConn{Conn: c1, tracker: &tracker, host: "auto.com"}
|
||||
tracker.add(tc)
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 1, len(tracker.conns))
|
||||
tracker.mu.Unlock()
|
||||
|
||||
// Close the tracked conn: should auto-deregister.
|
||||
require.NoError(t, tc.Close())
|
||||
|
||||
tracker.mu.Lock()
|
||||
assert.Equal(t, 0, len(tracker.conns), "should auto-deregister on close")
|
||||
tracker.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestHostOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"example.com:443", "example.com"},
|
||||
{"example.com", "example.com"},
|
||||
{"127.0.0.1:8080", "127.0.0.1"},
|
||||
{"[::1]:443", "[::1]"},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, hostOnly(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func (c *Client) printClients(data map[string]any) {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "SERVICES", "HAS CLIENT")
|
||||
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
|
||||
|
||||
for _, item := range clients {
|
||||
@@ -166,7 +166,7 @@ func (c *Client) printClientRow(item any) {
|
||||
return
|
||||
}
|
||||
|
||||
domains := c.extractDomains(client)
|
||||
services := c.extractServiceKeys(client)
|
||||
hasClient := "no"
|
||||
if hc, ok := client["has_client"].(bool); ok && hc {
|
||||
hasClient = "yes"
|
||||
@@ -175,20 +175,20 @@ func (c *Client) printClientRow(item any) {
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n",
|
||||
client["account_id"],
|
||||
client["age"],
|
||||
domains,
|
||||
services,
|
||||
hasClient,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Client) extractDomains(client map[string]any) string {
|
||||
d, ok := client["domains"].([]any)
|
||||
func (c *Client) extractServiceKeys(client map[string]any) string {
|
||||
d, ok := client["service_keys"].([]any)
|
||||
if !ok || len(d) == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
parts := make([]string, len(d))
|
||||
for i, domain := range d {
|
||||
parts[i] = fmt.Sprint(domain)
|
||||
for i, key := range d {
|
||||
parts[i] = fmt.Sprint(key)
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
@@ -189,7 +189,7 @@ type indexData struct {
|
||||
Version string
|
||||
Uptime string
|
||||
ClientCount int
|
||||
TotalDomains int
|
||||
TotalServices int
|
||||
CertsTotal int
|
||||
CertsReady int
|
||||
CertsPending int
|
||||
@@ -202,7 +202,7 @@ type indexData struct {
|
||||
|
||||
type clientData struct {
|
||||
AccountID string
|
||||
Domains string
|
||||
Services string
|
||||
Age string
|
||||
Status string
|
||||
}
|
||||
@@ -211,9 +211,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
sortedIDs := sortedAccountIDs(clients)
|
||||
|
||||
totalDomains := 0
|
||||
totalServices := 0
|
||||
for _, info := range clients {
|
||||
totalDomains += info.DomainCount
|
||||
totalServices += info.ServiceCount
|
||||
}
|
||||
|
||||
var certsTotal, certsReady, certsPending, certsFailed int
|
||||
@@ -234,24 +234,24 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"domain_count": info.DomainCount,
|
||||
"domains": info.Domains,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"version": version.NetbirdVersion(),
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"total_domains": totalDomains,
|
||||
"certs_total": certsTotal,
|
||||
"certs_ready": certsReady,
|
||||
"certs_pending": certsPending,
|
||||
"certs_failed": certsFailed,
|
||||
"clients": clientsJSON,
|
||||
"version": version.NetbirdVersion(),
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"total_services": totalServices,
|
||||
"certs_total": certsTotal,
|
||||
"certs_ready": certsReady,
|
||||
"certs_pending": certsPending,
|
||||
"certs_failed": certsFailed,
|
||||
"clients": clientsJSON,
|
||||
}
|
||||
if len(certsPendingDomains) > 0 {
|
||||
resp["certs_pending_domains"] = certsPendingDomains
|
||||
@@ -278,7 +278,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
Version: version.NetbirdVersion(),
|
||||
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||
ClientCount: len(clients),
|
||||
TotalDomains: totalDomains,
|
||||
TotalServices: totalServices,
|
||||
CertsTotal: certsTotal,
|
||||
CertsReady: certsReady,
|
||||
CertsPending: certsPending,
|
||||
@@ -291,9 +291,9 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
domains := info.Domains.SafeString()
|
||||
if domains == "" {
|
||||
domains = "-"
|
||||
services := strings.Join(info.ServiceKeys, ", ")
|
||||
if services == "" {
|
||||
services = "-"
|
||||
}
|
||||
status := "No client"
|
||||
if info.HasClient {
|
||||
@@ -301,7 +301,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
|
||||
}
|
||||
data.Clients = append(data.Clients, clientData{
|
||||
AccountID: string(info.AccountID),
|
||||
Domains: domains,
|
||||
Services: services,
|
||||
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
Status: status,
|
||||
})
|
||||
@@ -324,12 +324,12 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"domain_count": info.DomainCount,
|
||||
"domains": info.Domains,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
@@ -347,9 +347,9 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
domains := info.Domains.SafeString()
|
||||
if domains == "" {
|
||||
domains = "-"
|
||||
services := strings.Join(info.ServiceKeys, ", ")
|
||||
if services == "" {
|
||||
services = "-"
|
||||
}
|
||||
status := "No client"
|
||||
if info.HasClient {
|
||||
@@ -357,7 +357,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
}
|
||||
data.Clients = append(data.Clients, clientData{
|
||||
AccountID: string(info.AccountID),
|
||||
Domains: domains,
|
||||
Services: services,
|
||||
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
Status: status,
|
||||
})
|
||||
|
||||
@@ -12,14 +12,14 @@
|
||||
<table>
|
||||
<tr>
|
||||
<th>Account ID</th>
|
||||
<th>Domains</th>
|
||||
<th>Services</th>
|
||||
<th>Age</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
{{range .Clients}}
|
||||
<tr>
|
||||
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||
<td>{{.Domains}}</td>
|
||||
<td>{{.Services}}</td>
|
||||
<td>{{.Age}}</td>
|
||||
<td>{{.Status}}</td>
|
||||
</tr>
|
||||
|
||||
@@ -27,19 +27,19 @@
|
||||
<ul>{{range .CertsFailedDomains}}<li>{{.Domain}}: {{.Error}}</li>{{end}}</ul>
|
||||
</details>
|
||||
{{end}}
|
||||
<h2>Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})</h2>
|
||||
<h2>Clients ({{.ClientCount}}) | Services ({{.TotalServices}})</h2>
|
||||
{{if .Clients}}
|
||||
<table>
|
||||
<tr>
|
||||
<th>Account ID</th>
|
||||
<th>Domains</th>
|
||||
<th>Services</th>
|
||||
<th>Age</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
{{range .Clients}}
|
||||
<tr>
|
||||
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||
<td>{{.Domains}}</td>
|
||||
<td>{{.Services}}</td>
|
||||
<td>{{.Age}}</td>
|
||||
<td>{{.Status}}</td>
|
||||
</tr>
|
||||
|
||||
69
proxy/internal/metrics/l4_metrics_test.go
Normal file
69
proxy/internal/metrics/l4_metrics_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package metrics_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
promexporter "go.opentelemetry.io/otel/exporters/prometheus"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
func newTestMetrics(t *testing.T) *metrics.Metrics {
|
||||
t.Helper()
|
||||
|
||||
exporter, err := promexporter.New()
|
||||
if err != nil {
|
||||
t.Fatalf("create prometheus exporter: %v", err)
|
||||
}
|
||||
|
||||
provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exporter))
|
||||
pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath()
|
||||
meter := provider.Meter(pkg)
|
||||
|
||||
m, err := metrics.New(context.Background(), meter)
|
||||
if err != nil {
|
||||
t.Fatalf("create metrics: %v", err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestL4ServiceGauge(t *testing.T) {
|
||||
m := newTestMetrics(t)
|
||||
|
||||
m.L4ServiceAdded(types.ServiceModeTCP)
|
||||
m.L4ServiceAdded(types.ServiceModeTCP)
|
||||
m.L4ServiceAdded(types.ServiceModeUDP)
|
||||
m.L4ServiceRemoved(types.ServiceModeTCP)
|
||||
}
|
||||
|
||||
func TestTCPRelayMetrics(t *testing.T) {
|
||||
m := newTestMetrics(t)
|
||||
|
||||
acct := types.AccountID("acct-1")
|
||||
|
||||
m.TCPRelayStarted(acct)
|
||||
m.TCPRelayStarted(acct)
|
||||
m.TCPRelayEnded(acct, 10*time.Second, 1000, 500)
|
||||
m.TCPRelayDialError(acct)
|
||||
m.TCPRelayRejected(acct)
|
||||
}
|
||||
|
||||
func TestUDPSessionMetrics(t *testing.T) {
|
||||
m := newTestMetrics(t)
|
||||
|
||||
acct := types.AccountID("acct-2")
|
||||
|
||||
m.UDPSessionStarted(acct)
|
||||
m.UDPSessionStarted(acct)
|
||||
m.UDPSessionEnded(acct)
|
||||
m.UDPSessionDialError(acct)
|
||||
m.UDPSessionRejected(acct)
|
||||
m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 100)
|
||||
m.UDPPacketRelayed(types.RelayDirectionClientToBackend, 200)
|
||||
m.UDPPacketRelayed(types.RelayDirectionBackendToClient, 150)
|
||||
}
|
||||
@@ -6,12 +6,15 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// Metrics collects OpenTelemetry metrics for the proxy.
|
||||
type Metrics struct {
|
||||
ctx context.Context
|
||||
requestsTotal metric.Int64Counter
|
||||
@@ -22,85 +25,188 @@ type Metrics struct {
|
||||
backendDuration metric.Int64Histogram
|
||||
certificateIssueDuration metric.Int64Histogram
|
||||
|
||||
// L4 service-level metrics.
|
||||
l4Services metric.Int64UpDownCounter
|
||||
|
||||
// L4 TCP connection-level metrics.
|
||||
tcpActiveConns metric.Int64UpDownCounter
|
||||
tcpConnsTotal metric.Int64Counter
|
||||
tcpConnDuration metric.Int64Histogram
|
||||
tcpBytesTotal metric.Int64Counter
|
||||
|
||||
// L4 UDP session-level metrics.
|
||||
udpActiveSess metric.Int64UpDownCounter
|
||||
udpSessionsTotal metric.Int64Counter
|
||||
udpPacketsTotal metric.Int64Counter
|
||||
udpBytesTotal metric.Int64Counter
|
||||
|
||||
mappingsMux sync.Mutex
|
||||
mappingPaths map[string]int
|
||||
}
|
||||
|
||||
// New creates a Metrics instance using the given OpenTelemetry meter.
|
||||
func New(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
requestsTotal, err := meter.Int64Counter(
|
||||
m := &Metrics{
|
||||
ctx: ctx,
|
||||
mappingPaths: make(map[string]int),
|
||||
}
|
||||
|
||||
if err := m.initHTTPMetrics(meter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.initL4Metrics(meter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Metrics) initHTTPMetrics(meter metric.Meter) error {
|
||||
var err error
|
||||
|
||||
m.requestsTotal, err = meter.Int64Counter(
|
||||
"proxy.http.request.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total number of requests made to the netbird proxy"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
activeRequests, err := meter.Int64UpDownCounter(
|
||||
m.activeRequests, err = meter.Int64UpDownCounter(
|
||||
"proxy.http.active_requests",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Current in-flight requests handled by the netbird proxy"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
configuredDomains, err := meter.Int64UpDownCounter(
|
||||
m.configuredDomains, err = meter.Int64UpDownCounter(
|
||||
"proxy.domains.count",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Current number of domains configured on the netbird proxy"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
totalPaths, err := meter.Int64UpDownCounter(
|
||||
m.totalPaths, err = meter.Int64UpDownCounter(
|
||||
"proxy.paths.count",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total number of paths configured on the netbird proxy"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
requestDuration, err := meter.Int64Histogram(
|
||||
m.requestDuration, err = meter.Int64Histogram(
|
||||
"proxy.http.request.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of requests made to the netbird proxy"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
backendDuration, err := meter.Int64Histogram(
|
||||
m.backendDuration, err = meter.Int64Histogram(
|
||||
"proxy.backend.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of peer round trip time from the netbird proxy"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
certificateIssueDuration, err := meter.Int64Histogram(
|
||||
m.certificateIssueDuration, err = meter.Int64Histogram(
|
||||
"proxy.certificate.issue.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of ACME certificate issuance"),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Metrics) initL4Metrics(meter metric.Meter) error {
|
||||
var err error
|
||||
|
||||
m.l4Services, err = meter.Int64UpDownCounter(
|
||||
"proxy.l4.services.count",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Current number of configured L4 services (TCP/TLS/UDP) by mode"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return &Metrics{
|
||||
ctx: ctx,
|
||||
requestsTotal: requestsTotal,
|
||||
activeRequests: activeRequests,
|
||||
configuredDomains: configuredDomains,
|
||||
totalPaths: totalPaths,
|
||||
requestDuration: requestDuration,
|
||||
backendDuration: backendDuration,
|
||||
certificateIssueDuration: certificateIssueDuration,
|
||||
mappingPaths: make(map[string]int),
|
||||
}, nil
|
||||
m.tcpActiveConns, err = meter.Int64UpDownCounter(
|
||||
"proxy.tcp.active_connections",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Current number of active TCP/TLS relay connections"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.tcpConnsTotal, err = meter.Int64Counter(
|
||||
"proxy.tcp.connections.total",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total TCP/TLS relay connections by result and account"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.tcpConnDuration, err = meter.Int64Histogram(
|
||||
"proxy.tcp.connection.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of TCP/TLS relay connections"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.tcpBytesTotal, err = meter.Int64Counter(
|
||||
"proxy.tcp.bytes.total",
|
||||
metric.WithUnit("bytes"),
|
||||
metric.WithDescription("Total bytes transferred through TCP/TLS relay by direction"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.udpActiveSess, err = meter.Int64UpDownCounter(
|
||||
"proxy.udp.active_sessions",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Current number of active UDP relay sessions"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.udpSessionsTotal, err = meter.Int64Counter(
|
||||
"proxy.udp.sessions.total",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total UDP relay sessions by result and account"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.udpPacketsTotal, err = meter.Int64Counter(
|
||||
"proxy.udp.packets.total",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total UDP packets relayed by direction"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.udpBytesTotal, err = meter.Int64Counter(
|
||||
"proxy.udp.bytes.total",
|
||||
metric.WithUnit("bytes"),
|
||||
metric.WithDescription("Total bytes transferred through UDP relay by direction"),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
type responseInterceptor struct {
|
||||
@@ -120,6 +226,13 @@ func (w *responseInterceptor) Write(b []byte) (int, error) {
|
||||
return size, err
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter so http.ResponseController
|
||||
// can reach through to the original writer for Hijack/Flush operations.
|
||||
func (w *responseInterceptor) Unwrap() http.ResponseWriter {
|
||||
return w.PassthroughWriter
|
||||
}
|
||||
|
||||
// Middleware wraps an HTTP handler with request metrics.
|
||||
func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
m.requestsTotal.Add(m.ctx, 1)
|
||||
@@ -144,6 +257,7 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return f(r)
|
||||
}
|
||||
|
||||
// RoundTripper wraps an http.RoundTripper with backend duration metrics.
|
||||
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
|
||||
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
start := time.Now()
|
||||
@@ -156,6 +270,7 @@ func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
|
||||
})
|
||||
}
|
||||
|
||||
// AddMapping records that a domain mapping was added.
|
||||
func (m *Metrics) AddMapping(mapping proxy.Mapping) {
|
||||
m.mappingsMux.Lock()
|
||||
defer m.mappingsMux.Unlock()
|
||||
@@ -175,13 +290,13 @@ func (m *Metrics) AddMapping(mapping proxy.Mapping) {
|
||||
m.mappingPaths[mapping.Host] = newPathCount
|
||||
}
|
||||
|
||||
// RemoveMapping records that a domain mapping was removed.
|
||||
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
|
||||
m.mappingsMux.Lock()
|
||||
defer m.mappingsMux.Unlock()
|
||||
|
||||
oldPathCount, exists := m.mappingPaths[mapping.Host]
|
||||
if !exists {
|
||||
// Nothing to remove
|
||||
return
|
||||
}
|
||||
|
||||
@@ -195,3 +310,80 @@ func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
|
||||
func (m *Metrics) RecordCertificateIssuance(duration time.Duration) {
|
||||
m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
// L4ServiceAdded increments the L4 service gauge for the given mode.
|
||||
func (m *Metrics) L4ServiceAdded(mode types.ServiceMode) {
|
||||
m.l4Services.Add(m.ctx, 1, metric.WithAttributes(attribute.String("mode", string(mode))))
|
||||
}
|
||||
|
||||
// L4ServiceRemoved decrements the L4 service gauge for the given mode.
|
||||
func (m *Metrics) L4ServiceRemoved(mode types.ServiceMode) {
|
||||
m.l4Services.Add(m.ctx, -1, metric.WithAttributes(attribute.String("mode", string(mode))))
|
||||
}
|
||||
|
||||
// TCPRelayStarted records a new TCP relay connection starting.
|
||||
func (m *Metrics) TCPRelayStarted(accountID types.AccountID) {
|
||||
acct := attribute.String("account_id", string(accountID))
|
||||
m.tcpActiveConns.Add(m.ctx, 1, metric.WithAttributes(acct))
|
||||
m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success")))
|
||||
}
|
||||
|
||||
// TCPRelayEnded records a TCP relay connection ending and accumulates bytes and duration.
|
||||
func (m *Metrics) TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64) {
|
||||
acct := attribute.String("account_id", string(accountID))
|
||||
m.tcpActiveConns.Add(m.ctx, -1, metric.WithAttributes(acct))
|
||||
m.tcpConnDuration.Record(m.ctx, duration.Milliseconds(), metric.WithAttributes(acct))
|
||||
m.tcpBytesTotal.Add(m.ctx, srcToDst, metric.WithAttributes(attribute.String("direction", "client_to_backend")))
|
||||
m.tcpBytesTotal.Add(m.ctx, dstToSrc, metric.WithAttributes(attribute.String("direction", "backend_to_client")))
|
||||
}
|
||||
|
||||
// TCPRelayDialError records a dial failure for a TCP relay.
|
||||
func (m *Metrics) TCPRelayDialError(accountID types.AccountID) {
|
||||
m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(
|
||||
attribute.String("account_id", string(accountID)),
|
||||
attribute.String("result", "dial_error"),
|
||||
))
|
||||
}
|
||||
|
||||
// TCPRelayRejected records a rejected TCP relay (semaphore full).
|
||||
func (m *Metrics) TCPRelayRejected(accountID types.AccountID) {
|
||||
m.tcpConnsTotal.Add(m.ctx, 1, metric.WithAttributes(
|
||||
attribute.String("account_id", string(accountID)),
|
||||
attribute.String("result", "rejected"),
|
||||
))
|
||||
}
|
||||
|
||||
// UDPSessionStarted records a new UDP session starting.
|
||||
func (m *Metrics) UDPSessionStarted(accountID types.AccountID) {
|
||||
acct := attribute.String("account_id", string(accountID))
|
||||
m.udpActiveSess.Add(m.ctx, 1, metric.WithAttributes(acct))
|
||||
m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(acct, attribute.String("result", "success")))
|
||||
}
|
||||
|
||||
// UDPSessionEnded records a UDP session ending.
|
||||
func (m *Metrics) UDPSessionEnded(accountID types.AccountID) {
|
||||
m.udpActiveSess.Add(m.ctx, -1, metric.WithAttributes(attribute.String("account_id", string(accountID))))
|
||||
}
|
||||
|
||||
// UDPSessionDialError records a dial failure for a UDP session.
|
||||
func (m *Metrics) UDPSessionDialError(accountID types.AccountID) {
|
||||
m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(
|
||||
attribute.String("account_id", string(accountID)),
|
||||
attribute.String("result", "dial_error"),
|
||||
))
|
||||
}
|
||||
|
||||
// UDPSessionRejected records a rejected UDP session (limit or rate limited).
|
||||
func (m *Metrics) UDPSessionRejected(accountID types.AccountID) {
|
||||
m.udpSessionsTotal.Add(m.ctx, 1, metric.WithAttributes(
|
||||
attribute.String("account_id", string(accountID)),
|
||||
attribute.String("result", "rejected"),
|
||||
))
|
||||
}
|
||||
|
||||
// UDPPacketRelayed records a packet relayed in the given direction with its size in bytes.
|
||||
func (m *Metrics) UDPPacketRelayed(direction types.RelayDirection, bytes int) {
|
||||
dir := attribute.String("direction", string(direction))
|
||||
m.udpPacketsTotal.Add(m.ctx, 1, metric.WithAttributes(dir))
|
||||
m.udpBytesTotal.Add(m.ctx, int64(bytes), metric.WithAttributes(dir))
|
||||
}
|
||||
|
||||
40
proxy/internal/netutil/errors.go
Normal file
40
proxy/internal/netutil/errors.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// ValidatePort converts an int32 proto port to uint16, returning an error
|
||||
// if the value is out of the valid 1–65535 range.
|
||||
func ValidatePort(port int32) (uint16, error) {
|
||||
if port <= 0 || port > math.MaxUint16 {
|
||||
return 0, fmt.Errorf("invalid port %d: must be 1–65535", port)
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
// IsExpectedError returns true for errors that are normal during
|
||||
// connection teardown and should not be logged as warnings.
|
||||
func IsExpectedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, syscall.ECONNRESET) ||
|
||||
errors.Is(err, syscall.EPIPE) ||
|
||||
errors.Is(err, syscall.ECONNABORTED)
|
||||
}
|
||||
|
||||
// IsTimeout checks whether the error is a network timeout.
|
||||
func IsTimeout(err error) bool {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
92
proxy/internal/netutil/errors_test.go
Normal file
92
proxy/internal/netutil/errors_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidatePort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port int32
|
||||
want uint16
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid min", 1, 1, false},
|
||||
{"valid mid", 8080, 8080, false},
|
||||
{"valid max", 65535, 65535, false},
|
||||
{"zero", 0, 0, true},
|
||||
{"negative", -1, 0, true},
|
||||
{"too large", 65536, 0, true},
|
||||
{"way too large", 100000, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ValidatePort(tt.port)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Zero(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsExpectedError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{"net.ErrClosed", net.ErrClosed, true},
|
||||
{"context.Canceled", context.Canceled, true},
|
||||
{"io.EOF", io.EOF, true},
|
||||
{"ECONNRESET", syscall.ECONNRESET, true},
|
||||
{"EPIPE", syscall.EPIPE, true},
|
||||
{"ECONNABORTED", syscall.ECONNABORTED, true},
|
||||
{"wrapped expected", fmt.Errorf("wrap: %w", net.ErrClosed), true},
|
||||
{"unexpected EOF", io.ErrUnexpectedEOF, false},
|
||||
{"generic error", errors.New("something"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, IsExpectedError(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type timeoutErr struct{ timeout bool }
|
||||
|
||||
func (e *timeoutErr) Error() string { return "timeout" }
|
||||
func (e *timeoutErr) Timeout() bool { return e.timeout }
|
||||
func (e *timeoutErr) Temporary() bool { return false }
|
||||
|
||||
func TestIsTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{"net timeout", &timeoutErr{timeout: true}, true},
|
||||
{"net non-timeout", &timeoutErr{timeout: false}, false},
|
||||
{"wrapped timeout", fmt.Errorf("wrap: %w", &timeoutErr{timeout: true}), true},
|
||||
{"generic error", errors.New("not a timeout"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, IsTimeout(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
@@ -47,10 +48,10 @@ func (o ResponseOrigin) String() string {
|
||||
type CapturedData struct {
|
||||
mu sync.RWMutex
|
||||
RequestID string
|
||||
ServiceId string
|
||||
ServiceId types.ServiceID
|
||||
AccountId types.AccountID
|
||||
Origin ResponseOrigin
|
||||
ClientIP string
|
||||
ClientIP netip.Addr
|
||||
UserID string
|
||||
AuthMethod string
|
||||
}
|
||||
@@ -63,14 +64,14 @@ func (c *CapturedData) GetRequestID() string {
|
||||
}
|
||||
|
||||
// SetServiceId safely sets the service ID
|
||||
func (c *CapturedData) SetServiceId(serviceId string) {
|
||||
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ServiceId = serviceId
|
||||
}
|
||||
|
||||
// GetServiceId safely gets the service ID
|
||||
func (c *CapturedData) GetServiceId() string {
|
||||
func (c *CapturedData) GetServiceId() types.ServiceID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ServiceId
|
||||
@@ -105,14 +106,14 @@ func (c *CapturedData) GetOrigin() ResponseOrigin {
|
||||
}
|
||||
|
||||
// SetClientIP safely sets the resolved client IP.
|
||||
func (c *CapturedData) SetClientIP(ip string) {
|
||||
func (c *CapturedData) SetClientIP(ip netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ClientIP = ip
|
||||
}
|
||||
|
||||
// GetClientIP safely gets the resolved client IP.
|
||||
func (c *CapturedData) GetClientIP() string {
|
||||
func (c *CapturedData) GetClientIP() netip.Addr {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ClientIP
|
||||
@@ -161,13 +162,13 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
|
||||
return data
|
||||
}
|
||||
|
||||
func withServiceId(ctx context.Context, serviceId string) context.Context {
|
||||
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
|
||||
return context.WithValue(ctx, serviceIdKey, serviceId)
|
||||
}
|
||||
|
||||
func ServiceIdFromContext(ctx context.Context) string {
|
||||
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
|
||||
v := ctx.Value(serviceIdKey)
|
||||
serviceId, ok := v.(string)
|
||||
serviceId, ok := v.(types.ServiceID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
func BenchmarkServeHTTP(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
ID: types.ServiceID(rand.Text()),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
@@ -66,7 +66,7 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
|
||||
target = id
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: id,
|
||||
ID: types.ServiceID(id),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: host,
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
@@ -118,7 +118,7 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
|
||||
}
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
ID: types.ServiceID(rand.Text()),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: paths,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
@@ -86,9 +87,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx = roundtrip.WithSkipTLSVerify(ctx)
|
||||
}
|
||||
if pt.RequestTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
|
||||
defer cancel()
|
||||
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
|
||||
}
|
||||
|
||||
rewriteMatchedPath := result.matchedPath
|
||||
@@ -142,9 +141,9 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
||||
r.Out.Header.Set(k, v)
|
||||
}
|
||||
|
||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||
clientIP := extractHostIP(r.In.RemoteAddr)
|
||||
|
||||
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
||||
if isTrustedAddr(clientIP, p.trustedProxies) {
|
||||
p.setTrustedForwardingHeaders(r, clientIP)
|
||||
} else {
|
||||
p.setUntrustedForwardingHeaders(r, clientIP)
|
||||
@@ -214,12 +213,14 @@ func normalizeHost(u *url.URL) string {
|
||||
// setTrustedForwardingHeaders appends to the existing forwarding header chain
|
||||
// and preserves upstream-provided headers when the direct connection is from
|
||||
// a trusted proxy.
|
||||
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
|
||||
ipStr := clientIP.String()
|
||||
|
||||
// Append the direct connection IP to the existing X-Forwarded-For chain.
|
||||
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
|
||||
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-For", ipStr)
|
||||
}
|
||||
|
||||
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
|
||||
@@ -227,7 +228,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
|
||||
r.Out.Header.Set("X-Real-IP", realIP)
|
||||
} else {
|
||||
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
|
||||
r.Out.Header.Set("X-Real-IP", resolved)
|
||||
r.Out.Header.Set("X-Real-IP", resolved.String())
|
||||
}
|
||||
|
||||
// Preserve upstream X-Forwarded-Host if present.
|
||||
@@ -257,10 +258,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
|
||||
// sets them fresh based on the direct connection. This is the default
|
||||
// behavior when no trusted proxies are configured or the direct connection
|
||||
// is from an untrusted source.
|
||||
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
|
||||
ipStr := clientIP.String()
|
||||
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
r.Out.Header.Set("X-Real-IP", clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-For", ipStr)
|
||||
r.Out.Header.Set("X-Real-IP", ipStr)
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
r.Out.Header.Set("X-Forwarded-Proto", proto)
|
||||
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
|
||||
@@ -288,16 +290,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
|
||||
// which is always in host:port format.
|
||||
func extractClientIP(remoteAddr string) string {
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return remoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// extractForwardedPort returns the port from the Host header if present,
|
||||
// otherwise defaults to the standard port for the resolved protocol.
|
||||
func extractForwardedPort(host, resolvedProto string) string {
|
||||
@@ -327,10 +319,12 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
web.ServeErrorPage(w, r, code, title, message, requestID, status)
|
||||
}
|
||||
|
||||
// getClientIP retrieves the resolved client IP from context.
|
||||
// getClientIP retrieves the resolved client IP string from context.
|
||||
func getClientIP(r *http.Request) string {
|
||||
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
|
||||
return capturedData.GetClientIP()
|
||||
if ip := capturedData.GetClientIP(); ip.IsValid() {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
func TestExtractHostIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
expected netip.Addr
|
||||
}{
|
||||
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
|
||||
{"IPv6 with port", "[::1]:12345", "::1"},
|
||||
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
|
||||
{"IPv6 without brackets fallback", "::1", "::1"},
|
||||
{"empty string fallback", "", ""},
|
||||
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
|
||||
{"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")},
|
||||
{"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")},
|
||||
{"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")},
|
||||
{"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")},
|
||||
{"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")},
|
||||
{"empty string fallback", "", netip.Addr{}},
|
||||
{"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
|
||||
assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,8 +30,9 @@ type PathTarget struct {
|
||||
CustomHeaders map[string]string
|
||||
}
|
||||
|
||||
// Mapping describes how a domain is routed by the HTTP reverse proxy.
|
||||
type Mapping struct {
|
||||
ID string
|
||||
ID types.ServiceID
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*PathTarget
|
||||
@@ -42,7 +43,7 @@ type Mapping struct {
|
||||
type targetResult struct {
|
||||
target *PathTarget
|
||||
matchedPath string
|
||||
serviceID string
|
||||
serviceID types.ServiceID
|
||||
accountID types.AccountID
|
||||
passHostHeader bool
|
||||
rewriteRedirects bool
|
||||
@@ -101,8 +102,13 @@ func (p *ReverseProxy) AddMapping(m Mapping) {
|
||||
p.mappings[m.Host] = m
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) RemoveMapping(m Mapping) {
|
||||
// RemoveMapping removes the mapping for the given host and reports whether it existed.
|
||||
func (p *ReverseProxy) RemoveMapping(m Mapping) bool {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
if _, ok := p.mappings[m.Host]; !ok {
|
||||
return false
|
||||
}
|
||||
delete(p.mappings, m.Host)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -7,21 +7,11 @@ import (
|
||||
|
||||
// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes.
|
||||
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||
if len(trusted) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
if err != nil || len(trusted) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range trusted {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return isTrustedAddr(addr.Unmap(), trusted)
|
||||
}
|
||||
|
||||
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
|
||||
@@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||
//
|
||||
// If the trusted list is empty or remoteAddr is not trusted, it returns the
|
||||
// remoteAddr IP directly (ignoring any forwarding headers).
|
||||
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
|
||||
remoteIP := extractClientIP(remoteAddr)
|
||||
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr {
|
||||
remoteIP := extractHostIP(remoteAddr)
|
||||
|
||||
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
|
||||
if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
@@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
if !IsTrustedProxy(ip, trusted) {
|
||||
return ip
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
if !isTrustedAddr(addr, trusted) {
|
||||
return addr
|
||||
}
|
||||
}
|
||||
|
||||
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||
return first
|
||||
if addr, err := netip.ParseAddr(first); err == nil {
|
||||
return addr.Unmap()
|
||||
}
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// extractHostIP parses the IP from a host:port string and returns it unmapped.
|
||||
func extractHostIP(hostPort string) netip.Addr {
|
||||
if ap, err := netip.ParseAddrPort(hostPort); err == nil {
|
||||
return ap.Addr().Unmap()
|
||||
}
|
||||
if addr, err := netip.ParseAddr(hostPort); err == nil {
|
||||
return addr.Unmap()
|
||||
}
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// isTrustedAddr checks if the given address falls within any of the trusted prefixes.
|
||||
func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool {
|
||||
if !addr.IsValid() {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range trusted {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) {
|
||||
remoteAddr string
|
||||
xff string
|
||||
trusted []netip.Prefix
|
||||
want string
|
||||
want netip.Addr
|
||||
}{
|
||||
{
|
||||
name: "empty trusted list returns RemoteAddr",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4",
|
||||
trusted: nil,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "untrusted RemoteAddr ignores XFF",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4, 10.0.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with single client in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr walks past trusted entries in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.1",
|
||||
want: netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
{
|
||||
name: "all XFF IPs trusted returns leftmost",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.2",
|
||||
want: netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
{
|
||||
name: "XFF with whitespace",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: " 203.0.113.50 , 10.0.0.2 ",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "XFF with empty segments",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50,,10.0.0.2",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "multi-hop with mixed trust",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr without port",
|
||||
remoteAddr: "10.0.0.1",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -14,11 +15,12 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -26,7 +28,22 @@ import (
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey = string
|
||||
type backendKey string
|
||||
|
||||
// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service)
|
||||
// that holds a reference to an embedded NetBird client. Callers should use the
|
||||
// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions.
|
||||
type ServiceKey string
|
||||
|
||||
// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service.
|
||||
func DomainServiceKey(domain string) ServiceKey {
|
||||
return ServiceKey("domain:" + domain)
|
||||
}
|
||||
|
||||
// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP).
|
||||
func L4ServiceKey(id types.ServiceID) ServiceKey {
|
||||
return ServiceKey("l4:" + id)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||
@@ -39,24 +56,24 @@ var (
|
||||
ErrTooManyInflight = errors.New("too many in-flight requests")
|
||||
)
|
||||
|
||||
// domainInfo holds metadata about a registered domain.
|
||||
type domainInfo struct {
|
||||
serviceID string
|
||||
// serviceInfo holds metadata about a registered service.
|
||||
type serviceInfo struct {
|
||||
serviceID types.ServiceID
|
||||
}
|
||||
|
||||
type domainNotification struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
type serviceNotification struct {
|
||||
key ServiceKey
|
||||
serviceID types.ServiceID
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
// clientEntry holds an embedded NetBird client and tracks which services use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
// insecureTransport is a clone of transport with TLS verification disabled,
|
||||
// used when per-target skip_tls_verify is set.
|
||||
insecureTransport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
services map[ServiceKey]serviceInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
@@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
|
||||
// ClientConfig holds configuration for the embedded NetBird client.
|
||||
type ClientConfig struct {
|
||||
MgmtAddr string
|
||||
WGPort int
|
||||
WGPort uint16
|
||||
PreSharedKey string
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
|
||||
NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error
|
||||
}
|
||||
|
||||
type managementClient interface {
|
||||
@@ -107,7 +124,7 @@ type managementClient interface {
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
|
||||
type NetBird struct {
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
@@ -124,11 +141,11 @@ type NetBird struct {
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
type ClientDebugInfo struct {
|
||||
AccountID types.AccountID
|
||||
DomainCount int
|
||||
Domains domain.List
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
AccountID types.AccountID
|
||||
ServiceCount int
|
||||
ServiceKeys []string
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// accountIDContextKey is the context key for storing the account ID.
|
||||
@@ -137,37 +154,37 @@ type accountIDContextKey struct{}
|
||||
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
|
||||
type skipTLSVerifyContextKey struct{}
|
||||
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// AddPeer registers a service for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple domains can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
|
||||
// Multiple services can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
// Client already exists for this account, just register the domain
|
||||
entry.domains[d] = domainInfo{serviceID: serviceID}
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("registered domain with existing client")
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
// If client is already started, notify this domain as connected immediately
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
|
||||
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
@@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
@@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
@@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: serviceID,
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
WireguardPublicKey: publicKey.String(),
|
||||
@@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
|
||||
// Create embedded NetBird client with the generated private key.
|
||||
// The peer has already been created via CreateProxyPeer RPC with the public key.
|
||||
wgPort := int(n.clientCfg.WGPort)
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
WireguardPort: &n.clientCfg.WGPort,
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
transport := &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
DialContext: dialWithTimeout(client.DialContext),
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
@@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
|
||||
return &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
services: map[ServiceKey]serviceInfo{key: si},
|
||||
transport: transport,
|
||||
insecureTransport: insecureTransport,
|
||||
createdAt: time.Now(),
|
||||
@@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runClientStartup starts the client and notifies registered domains on success.
|
||||
// runClientStartup starts the client and notifies registered services on success.
|
||||
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
@@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
return
|
||||
}
|
||||
|
||||
// Mark client as started and collect domains to notify outside the lock.
|
||||
// Mark client as started and collect services to notify outside the lock.
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
var domainsToNotify []domainNotification
|
||||
var toNotify []serviceNotification
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
|
||||
for key, info := range entry.services {
|
||||
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
@@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
for _, dn := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
|
||||
for _, sn := range toNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
"account_id": accountID,
|
||||
"service_key": sn.key,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
"account_id": accountID,
|
||||
"service_key": sn.key,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
// when no domains are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
|
||||
// RemovePeer unregisters a service from an account. The client is only stopped
|
||||
// when no services are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
@@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get domain info before deleting
|
||||
domInfo, domainExists := entry.domains[d]
|
||||
if !domainExists {
|
||||
si, svcExists := entry.services[key]
|
||||
if !svcExists {
|
||||
n.clientsMux.Unlock()
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("remove peer: domain not registered")
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("remove peer: service not registered")
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(entry.domains, d)
|
||||
|
||||
// If there are still domains using this client, keep it running
|
||||
if len(entry.domains) > 0 {
|
||||
n.clientsMux.Unlock()
|
||||
delete(entry.services, key)
|
||||
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"remaining_domains": len(entry.domains),
|
||||
}).Debug("unregistered domain, client still in use")
|
||||
|
||||
// Notify this domain as disconnected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
"remaining_services": len(entry.services),
|
||||
}).Debug("unregistered service, client still in use")
|
||||
}
|
||||
|
||||
// No more domains using this client, stop it
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Info("stopping client, no more domains")
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
insecureTransport := entry.insecureTransport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify disconnection before stopping
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil {
|
||||
if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper. It looks up the client for the account
|
||||
// specified in the request context and uses it to dial the backend.
|
||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
release, ok := entry.acquireInflight(req.URL.Host)
|
||||
release, ok := entry.acquireInflight(backendKey(req.URL.Host))
|
||||
defer release()
|
||||
if !ok {
|
||||
return nil, ErrTooManyInflight
|
||||
@@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
||||
return exists
|
||||
}
|
||||
|
||||
// DomainCount returns the number of domains registered for the given account.
|
||||
// ServiceCount returns the number of services registered for the given account.
|
||||
// Returns 0 if the account has no client.
|
||||
func (n *NetBird) DomainCount(accountID types.AccountID) int {
|
||||
func (n *NetBird) ServiceCount(accountID types.AccountID) int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
return len(entry.domains)
|
||||
return len(entry.services)
|
||||
}
|
||||
|
||||
// ClientCount returns the total number of active clients.
|
||||
@@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
|
||||
result := make(map[types.AccountID]ClientDebugInfo)
|
||||
for accountID, entry := range n.clients {
|
||||
domains := make(domain.List, 0, len(entry.domains))
|
||||
for d := range entry.domains {
|
||||
domains = append(domains, d)
|
||||
keys := make([]string, 0, len(entry.services))
|
||||
for k := range entry.services {
|
||||
keys = append(keys, string(k))
|
||||
}
|
||||
result[accountID] = ClientDebugInfo{
|
||||
AccountID: accountID,
|
||||
DomainCount: len(entry.domains),
|
||||
Domains: domains,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
AccountID: accountID,
|
||||
ServiceCount: len(entry.services),
|
||||
ServiceKeys: keys,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
}
|
||||
}
|
||||
return result
|
||||
@@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
|
||||
}
|
||||
}
|
||||
|
||||
// dialWithTimeout wraps a DialContext function so that any dial timeout
|
||||
// stored in the context (via types.WithDialTimeout) is applied only to
|
||||
// the connection establishment phase, not the full request lifetime.
|
||||
func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if d, ok := types.DialTimeoutFromContext(ctx); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, d)
|
||||
defer cancel()
|
||||
}
|
||||
return dial(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAccountID adds the account ID to the context.
|
||||
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIDContextKey{}, accountID)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"sync"
|
||||
@@ -8,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// Simple benchmark for comparison with AddPeer contention.
|
||||
@@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) {
|
||||
target = id
|
||||
}
|
||||
nb.clients[id] = &clientEntry{
|
||||
domains: map[domain.Domain]domainInfo{
|
||||
domain.Domain(rand.Text()): {
|
||||
serviceID: rand.Text(),
|
||||
services: map[ServiceKey]serviceInfo{
|
||||
ServiceKey(rand.Text()): {
|
||||
serviceID: types.ServiceID(rand.Text()),
|
||||
},
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
@@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
|
||||
target = id
|
||||
}
|
||||
nb.clients[id] = &clientEntry{
|
||||
domains: map[domain.Domain]domainInfo{
|
||||
domain.Domain(rand.Text()): {
|
||||
serviceID: rand.Text(),
|
||||
services: map[ServiceKey]serviceInfo{
|
||||
ServiceKey(rand.Text()): {
|
||||
serviceID: types.ServiceID(rand.Text()),
|
||||
},
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
@@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
|
||||
}
|
||||
|
||||
// Launch workers that continuously call AddPeer with new random accountIDs.
|
||||
ctx, cancel := context.WithCancel(b.Context())
|
||||
var wg sync.WaitGroup
|
||||
for range addPeerWorkers {
|
||||
wg.Go(func() {
|
||||
for {
|
||||
if err := nb.AddPeer(b.Context(),
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for ctx.Err() == nil {
|
||||
if err := nb.AddPeer(ctx,
|
||||
types.AccountID(rand.Text()),
|
||||
domain.Domain(rand.Text()),
|
||||
ServiceKey(rand.Text()),
|
||||
rand.Text(),
|
||||
rand.Text()); err != nil {
|
||||
b.Log(err)
|
||||
types.ServiceID(rand.Text())); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
// Benchmark calling HasClient during AddPeer contention.
|
||||
@@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.StopTimer()
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -27,16 +26,15 @@ type mockStatusNotifier struct {
|
||||
}
|
||||
|
||||
type statusCall struct {
|
||||
accountID string
|
||||
serviceID string
|
||||
domain string
|
||||
accountID types.AccountID
|
||||
serviceID types.ServiceID
|
||||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -62,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
|
||||
// Initially no client exists.
|
||||
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
|
||||
|
||||
// Add first domain - this should create a new client.
|
||||
// Note: This will fail to actually connect since we use an invalid URL,
|
||||
// but the client entry should still be created.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add first service - this should create a new client.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add first service.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID))
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID))
|
||||
|
||||
// Add second domain for the same account - should reuse existing client.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
// Add second service for the same account - should reuse existing client.
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
|
||||
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service")
|
||||
|
||||
// Add third domain.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
// Add third service.
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
|
||||
assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service")
|
||||
|
||||
// Still only one client.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
@@ -102,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
|
||||
// Add domain for account 1.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add service for account 1.
|
||||
err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add domain for account 2.
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
|
||||
// Add service for account 2.
|
||||
err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both accounts should have their own clients.
|
||||
assert.True(t, nb.HasClient(account1), "account1 should have client")
|
||||
assert.True(t, nb.HasClient(account2), "account2 should have client")
|
||||
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
|
||||
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
|
||||
assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1")
|
||||
assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
|
||||
func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add multiple domains.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add multiple services.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID))
|
||||
assert.Equal(t, 3, nb.ServiceCount(accountID))
|
||||
|
||||
// Remove one domain - client should remain.
|
||||
// Remove one service - client should remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing one service")
|
||||
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2")
|
||||
|
||||
// Remove another domain - client should still remain.
|
||||
// Remove another service - client should still remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing second service")
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
|
||||
func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add single domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add single service.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
// Remove the only domain - client should be removed.
|
||||
// Note: Stop() may fail since the client never actually connected,
|
||||
// but the entry should still be removed from the map.
|
||||
// Remove the only service - client should be removed.
|
||||
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
|
||||
// After removing all domains, client should be gone.
|
||||
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
// After removing all services, client should be gone.
|
||||
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service")
|
||||
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||
@@ -171,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||
assert.NoError(t, err, "removing from non-existent account should not error")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
|
||||
func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add one domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add one service.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove non-existent domain - should not affect existing domain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
|
||||
// Remove non-existent service - should not affect existing service.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original domain should still be registered.
|
||||
// Original service should still be registered.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain")
|
||||
}
|
||||
|
||||
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
|
||||
@@ -216,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
|
||||
account2 := types.AccountID("account-2")
|
||||
account3 := types.AccountID("account-3")
|
||||
|
||||
// Add domains for multiple accounts.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
// Add services for multiple accounts.
|
||||
err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
|
||||
err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
|
||||
|
||||
// Stop all clients.
|
||||
// Note: StopAll may return errors since clients never actually connected,
|
||||
// but the clients should still be removed from the map.
|
||||
_ = nb.StopAll(context.Background())
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
|
||||
@@ -243,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) {
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
|
||||
|
||||
// Add clients for different accounts.
|
||||
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.ClientCount())
|
||||
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount())
|
||||
|
||||
// Adding domain to existing account should not increase count.
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
|
||||
// Adding service to existing account should not increase count.
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
|
||||
assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count")
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
|
||||
@@ -293,8 +285,8 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
// Add first service — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually mark client as started to simulate background startup completing.
|
||||
@@ -302,15 +294,14 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
nb.clients[accountID].started = true
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// Add second domain — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
// Add second service — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, string(accountID), calls[0].accountID)
|
||||
assert.Equal(t, "svc-2", calls[0].serviceID)
|
||||
assert.Equal(t, "domain2.test", calls[0].domain)
|
||||
assert.Equal(t, accountID, calls[0].accountID)
|
||||
assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID)
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
@@ -323,18 +314,18 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove one domain — client stays, but disconnection notification fires.
|
||||
// Remove one service — client stays, but disconnection notification fires.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "domain1.test", calls[0].domain)
|
||||
assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID)
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
133
proxy/internal/tcp/bench_test.go
Normal file
133
proxy/internal/tcp/bench_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// BenchmarkPeekClientHello_TLS measures the overhead of peeking at a real
|
||||
// TLS ClientHello and extracting the SNI. This is the per-connection cost
|
||||
// added to every TLS connection on the main listener.
|
||||
func BenchmarkPeekClientHello_TLS(b *testing.B) {
|
||||
// Pre-generate a ClientHello by capturing what crypto/tls sends.
|
||||
clientConn, serverConn := net.Pipe()
|
||||
go func() {
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: "app.example.com",
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
var hello []byte
|
||||
buf := make([]byte, 16384)
|
||||
n, _ := serverConn.Read(buf)
|
||||
hello = make([]byte, n)
|
||||
copy(hello, buf[:n])
|
||||
clientConn.Close()
|
||||
serverConn.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(hello)
|
||||
conn := &readerConn{Reader: r}
|
||||
sni, wrapped, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if sni != "app.example.com" {
|
||||
b.Fatalf("unexpected SNI: %q", sni)
|
||||
}
|
||||
// Simulate draining the peeked bytes (what the HTTP server would do).
|
||||
_, _ = io.Copy(io.Discard, wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPeekClientHello_NonTLS measures peek overhead for non-TLS
|
||||
// connections that hit the fast non-handshake exit path.
|
||||
func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
|
||||
httpReq := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(httpReq)
|
||||
conn := &readerConn{Reader: r}
|
||||
_, wrapped, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPeekedConn_Read measures the read overhead of the peekedConn
|
||||
// wrapper compared to a plain connection read. The peeked bytes use
|
||||
// io.MultiReader which adds one indirection per Read call.
|
||||
func BenchmarkPeekedConn_Read(b *testing.B) {
|
||||
data := make([]byte, 4096)
|
||||
peeked := make([]byte, 512)
|
||||
buf := make([]byte, 1024)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(data)
|
||||
conn := &readerConn{Reader: r}
|
||||
pc := newPeekedConn(conn, peeked)
|
||||
for {
|
||||
_, err := pc.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkExtractSNI measures just the in-memory SNI parsing cost,
|
||||
// excluding I/O.
|
||||
func BenchmarkExtractSNI(b *testing.B) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
go func() {
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: "app.example.com",
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 16384)
|
||||
n, _ := serverConn.Read(buf)
|
||||
payload := make([]byte, n-tlsRecordHeaderLen)
|
||||
copy(payload, buf[tlsRecordHeaderLen:n])
|
||||
clientConn.Close()
|
||||
serverConn.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
sni := extractSNI(payload)
|
||||
if sni != "app.example.com" {
|
||||
b.Fatalf("unexpected SNI: %q", sni)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readerConn wraps an io.Reader as a net.Conn for benchmarking.
|
||||
// Only Read is functional; all other methods are no-ops.
|
||||
type readerConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *readerConn) Read(b []byte) (int, error) {
|
||||
return c.Reader.Read(b)
|
||||
}
|
||||
76
proxy/internal/tcp/chanlistener.go
Normal file
76
proxy/internal/tcp/chanlistener.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// chanListener implements net.Listener by reading connections from a channel.
|
||||
// It allows the SNI router to feed HTTP connections to http.Server.ServeTLS.
|
||||
type chanListener struct {
|
||||
ch chan net.Conn
|
||||
addr net.Addr
|
||||
once sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newChanListener(ch chan net.Conn, addr net.Addr) *chanListener {
|
||||
return &chanListener{
|
||||
ch: ch,
|
||||
addr: addr,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection from the channel.
|
||||
func (l *chanListener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
select {
|
||||
case conn, ok := <-l.ch:
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
// Drain buffered connections before returning.
|
||||
for {
|
||||
select {
|
||||
case conn, ok := <-l.ch:
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
_ = conn.Close()
|
||||
default:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close signals the listener to stop accepting connections and drains
|
||||
// any buffered connections that have not yet been accepted.
|
||||
func (l *chanListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
close(l.closed)
|
||||
for {
|
||||
select {
|
||||
case conn, ok := <-l.ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
_ = conn.Close()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *chanListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
var _ net.Listener = (*chanListener)(nil)
|
||||
39
proxy/internal/tcp/peekedconn.go
Normal file
39
proxy/internal/tcp/peekedconn.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// peekedConn wraps a net.Conn and prepends previously peeked bytes
|
||||
// so that readers see the full original stream transparently.
|
||||
type peekedConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newPeekedConn(conn net.Conn, peeked []byte) *peekedConn {
|
||||
return &peekedConn{
|
||||
Conn: conn,
|
||||
reader: io.MultiReader(bytes.NewReader(peeked), conn),
|
||||
}
|
||||
}
|
||||
|
||||
// Read replays the peeked bytes first, then reads from the underlying conn.
|
||||
func (c *peekedConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// CloseWrite delegates to the underlying connection if it supports
|
||||
// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn
|
||||
// as an interface hides the concrete type's CloseWrite method, making
|
||||
// half-close a silent no-op for all SNI-routed connections.
|
||||
func (c *peekedConn) CloseWrite() error {
|
||||
if hc, ok := c.Conn.(halfCloser); ok {
|
||||
return hc.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ halfCloser = (*peekedConn)(nil)
|
||||
29
proxy/internal/tcp/proxyprotocol.go
Normal file
29
proxy/internal/tcp/proxyprotocol.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
)
|
||||
|
||||
// writeProxyProtoV2 sends a PROXY protocol v2 header to the backend connection,
|
||||
// conveying the real client address.
|
||||
func writeProxyProtoV2(client, backend net.Conn) error {
|
||||
tp := proxyproto.TCPv4
|
||||
if addr, ok := client.RemoteAddr().(*net.TCPAddr); ok && addr.IP.To4() == nil {
|
||||
tp = proxyproto.TCPv6
|
||||
}
|
||||
|
||||
header := &proxyproto.Header{
|
||||
Version: 2,
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: tp,
|
||||
SourceAddr: client.RemoteAddr(),
|
||||
DestinationAddr: client.LocalAddr(),
|
||||
}
|
||||
if _, err := header.WriteTo(backend); err != nil {
|
||||
return fmt.Errorf("write PROXY protocol v2 header: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
128
proxy/internal/tcp/proxyprotocol_test.go
Normal file
128
proxy/internal/tcp/proxyprotocol_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteProxyProtoV2_IPv4(t *testing.T) {
|
||||
// Set up a real TCP listener and dial to get connections with real addresses.
|
||||
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
var serverConn net.Conn
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
var err error
|
||||
serverConn, err = ln.Accept()
|
||||
if err != nil {
|
||||
t.Error("accept failed:", err)
|
||||
}
|
||||
close(accepted)
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp4", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
<-accepted
|
||||
defer serverConn.Close()
|
||||
|
||||
// Use a pipe as the backend: write the header to one end, read from the other.
|
||||
backendRead, backendWrite := net.Pipe()
|
||||
defer backendRead.Close()
|
||||
defer backendWrite.Close()
|
||||
|
||||
// serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination.
|
||||
writeDone := make(chan error, 1)
|
||||
go func() {
|
||||
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
|
||||
}()
|
||||
|
||||
// Read the PROXY protocol header from the backend read side.
|
||||
header, err := proxyproto.Read(bufio.NewReader(backendRead))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, header, "should have received a proxy protocol header")
|
||||
|
||||
writeErr := <-writeDone
|
||||
require.NoError(t, writeErr)
|
||||
|
||||
assert.Equal(t, byte(2), header.Version, "version should be 2")
|
||||
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
|
||||
assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4")
|
||||
|
||||
// serverConn.RemoteAddr() is the client's address (source in the header).
|
||||
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
|
||||
actualSrc := header.SourceAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
|
||||
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
|
||||
|
||||
// serverConn.LocalAddr() is the server's address (destination in the header).
|
||||
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
|
||||
actualDst := header.DestinationAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
|
||||
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
|
||||
}
|
||||
|
||||
func TestWriteProxyProtoV2_IPv6(t *testing.T) {
|
||||
// Set up a real TCP6 listener on loopback.
|
||||
ln, err := net.Listen("tcp6", "[::1]:0")
|
||||
if err != nil {
|
||||
t.Skip("IPv6 not available:", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var serverConn net.Conn
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
var err error
|
||||
serverConn, err = ln.Accept()
|
||||
if err != nil {
|
||||
t.Error("accept failed:", err)
|
||||
}
|
||||
close(accepted)
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp6", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
<-accepted
|
||||
defer serverConn.Close()
|
||||
|
||||
backendRead, backendWrite := net.Pipe()
|
||||
defer backendRead.Close()
|
||||
defer backendWrite.Close()
|
||||
|
||||
writeDone := make(chan error, 1)
|
||||
go func() {
|
||||
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
|
||||
}()
|
||||
|
||||
header, err := proxyproto.Read(bufio.NewReader(backendRead))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, header, "should have received a proxy protocol header")
|
||||
|
||||
writeErr := <-writeDone
|
||||
require.NoError(t, writeErr)
|
||||
|
||||
assert.Equal(t, byte(2), header.Version, "version should be 2")
|
||||
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
|
||||
assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6")
|
||||
|
||||
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
|
||||
actualSrc := header.SourceAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
|
||||
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
|
||||
|
||||
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
|
||||
actualDst := header.DestinationAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
|
||||
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
|
||||
}
|
||||
156
proxy/internal/tcp/relay.go
Normal file
156
proxy/internal/tcp/relay.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
)
|
||||
|
||||
// errIdleTimeout is returned when a relay connection is closed due to inactivity.
// It is treated as an expected error by isExpectedCopyError and only logged
// at debug level.
var errIdleTimeout = errors.New("idle timeout")

// DefaultIdleTimeout is the default idle timeout for TCP relay connections.
// A zero value disables idle timeout checking.
const DefaultIdleTimeout = 5 * time.Minute

// halfCloser is implemented by connections that support half-close
// (e.g. *net.TCPConn). When one copy direction finishes, we signal
// EOF to the remote by closing the write side while keeping the read
// side open so the other direction can drain.
type halfCloser interface {
	// CloseWrite shuts down the writing side of the connection.
	CloseWrite() error
}

// copyBufPool avoids allocating a new 32KB buffer per io.Copy call.
// Pointers to slices are pooled (rather than slices) so Put does not
// allocate an interface box for the slice header.
var copyBufPool = sync.Pool{
	New: func() any {
		buf := make([]byte, 32*1024)
		return &buf
	},
}
|
||||
|
||||
// Relay copies data bidirectionally between src and dst until both
// sides are done or the context is canceled. When idleTimeout is
// non-zero, each direction's read is deadline-guarded; if no data
// flows within the timeout the connection is torn down. When one
// direction finishes, it half-closes the write side of the
// destination (if supported) to signal EOF, allowing the other
// direction to drain gracefully before the full connection teardown.
// It returns the number of bytes copied in each direction.
func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Teardown goroutine: on any cancellation (parent context, or either
	// copy direction finishing) both connections are closed, which
	// unblocks a Read/Write the opposite direction may be stuck in.
	go func() {
		<-ctx.Done()
		_ = src.Close()
		_ = dst.Close()
	}()

	var wg sync.WaitGroup
	wg.Add(2)

	var errSrcToDst, errDstToSrc error

	go func() {
		defer wg.Done()
		srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout)
		halfClose(dst)
		// NOTE(review): cancel() immediately wakes the teardown goroutine
		// above, which fully closes both connections — so the drain window
		// opened by the preceding half-close is effectively instantaneous.
		// Confirm this is the intended teardown semantics.
		cancel()
	}()

	go func() {
		defer wg.Done()
		dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout)
		halfClose(src)
		cancel()
	}()

	wg.Wait()

	// Expected terminations (idle timeout, peer close) are logged at
	// debug level only; anything else is still debug but with the error.
	if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) {
		logger.Debug("relay closed due to idle timeout")
	}
	if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) {
		logger.Debugf("relay copy error (src→dst): %v", errSrcToDst)
	}
	if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) {
		logger.Debugf("relay copy error (dst→src): %v", errDstToSrc)
	}

	return srcToDst, dstToSrc
}
|
||||
|
||||
// copyWithIdleTimeout copies from src to dst using a pooled buffer.
|
||||
// When idleTimeout > 0 it sets a read deadline on src before each
|
||||
// read and treats a timeout as an idle-triggered close.
|
||||
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
if idleTimeout <= 0 {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
conn, ok := src.(net.Conn)
|
||||
if !ok {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
|
||||
return total, err
|
||||
}
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
n, err := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if netutil.IsTimeout(readErr) {
|
||||
return total, errIdleTimeout
|
||||
}
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkedWrite writes buf to dst in full, mirroring io.Copy's defensive
// handling of misbehaving writers: a negative or out-of-range count is
// treated as zero, and a short write without an error is reported as
// io.ErrShortWrite. It returns the number of bytes written.
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
	n, err := dst.Write(buf)
	if n < 0 || n > len(buf) {
		// A broken Writer returned an impossible count; don't trust it.
		n = 0
	}

	switch {
	case err != nil:
		return int64(n), err
	case n != len(buf):
		return int64(n), io.ErrShortWrite
	default:
		return int64(n), nil
	}
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err)
|
||||
}
|
||||
|
||||
// halfClose attempts to half-close the write side of the connection.
|
||||
// If the connection does not support half-close, this is a no-op.
|
||||
func halfClose(conn net.Conn) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
// Best-effort; the full close will follow shortly.
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
210
proxy/internal/tcp/relay_test.go
Normal file
210
proxy/internal/tcp/relay_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
)
|
||||
|
||||
// TestRelay_BidirectionalCopy verifies that Relay moves bytes in both
// directions over in-memory pipes and reports accurate per-direction
// byte counts.
func TestRelay_BidirectionalCopy(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()

	logger := log.NewEntry(log.StandardLogger())
	ctx := context.Background()

	srcData := []byte("hello from src")
	dstData := []byte("hello from dst")

	// dst side: write response first, then read + close.
	go func() {
		_, _ = dstClient.Write(dstData)
		buf := make([]byte, 256)
		_, _ = dstClient.Read(buf)
		dstClient.Close()
	}()

	// src side: read the response, then send data + close.
	go func() {
		buf := make([]byte, 256)
		_, _ = srcClient.Read(buf)
		_, _ = srcClient.Write(srcData)
		srcClient.Close()
	}()

	// Idle timeout 0 disables deadline checks (net.Pipe does not
	// support deadlines anyway).
	s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0)

	assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst")
	assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src")
}
|
||||
|
||||
// TestRelay_ContextCancellation verifies that canceling the parent
// context unblocks both copy directions and makes Relay return promptly.
func TestRelay_ContextCancellation(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()
	defer srcClient.Close()
	defer dstClient.Close()

	logger := log.NewEntry(log.StandardLogger())
	ctx, cancel := context.WithCancel(context.Background())

	done := make(chan struct{})
	go func() {
		Relay(ctx, logger, srcServer, dstServer, 0)
		close(done)
	}()

	// Cancel should cause Relay to return.
	cancel()

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("Relay did not return after context cancellation")
	}
}
|
||||
|
||||
// TestRelay_OneSideClosed verifies that Relay terminates (rather than
// hanging) when one endpoint is already closed before the relay starts.
func TestRelay_OneSideClosed(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()
	defer dstClient.Close()

	logger := log.NewEntry(log.StandardLogger())
	ctx := context.Background()

	// Close src immediately. Relay should complete without hanging.
	srcClient.Close()

	done := make(chan struct{})
	go func() {
		Relay(ctx, logger, srcServer, dstServer, 0)
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("Relay did not return after one side closed")
	}
}
|
||||
|
||||
// TestRelay_LargeTransfer pushes 1MB through the relay in one direction
// and checks that every byte arrives and is counted.
func TestRelay_LargeTransfer(t *testing.T) {
	srcClient, srcServer := net.Pipe()
	dstClient, dstServer := net.Pipe()

	logger := log.NewEntry(log.StandardLogger())
	ctx := context.Background()

	// 1MB of data with a deterministic byte pattern.
	data := make([]byte, 1<<20)
	for i := range data {
		data[i] = byte(i % 256)
	}

	go func() {
		_, _ = srcClient.Write(data)
		srcClient.Close()
	}()

	// Consumer goroutine: drain the dst side and report byte-count
	// mismatches through errCh.
	errCh := make(chan error, 1)
	go func() {
		received, err := io.ReadAll(dstClient)
		if err != nil {
			errCh <- err
			return
		}
		if len(received) != len(data) {
			errCh <- fmt.Errorf("expected %d bytes, got %d", len(data), len(received))
			return
		}
		errCh <- nil
		dstClient.Close()
	}()

	s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0)
	assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes")
	require.NoError(t, <-errCh)
}
|
||||
|
||||
// TestRelay_IdleTimeout verifies that a relay with a short idle timeout
// forwards live traffic and then tears itself down once traffic stops.
func TestRelay_IdleTimeout(t *testing.T) {
	// Use real TCP connections so SetReadDeadline works (net.Pipe
	// does not support deadlines).
	srcLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer srcLn.Close()

	dstLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer dstLn.Close()

	srcClient, err := net.Dial("tcp", srcLn.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer srcClient.Close()

	srcServer, err := srcLn.Accept()
	if err != nil {
		t.Fatal(err)
	}

	dstClient, err := net.Dial("tcp", dstLn.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer dstClient.Close()

	dstServer, err := dstLn.Accept()
	if err != nil {
		t.Fatal(err)
	}

	logger := log.NewEntry(log.StandardLogger())
	ctx := context.Background()

	// Send initial data to prove the relay works.
	go func() {
		_, _ = srcClient.Write([]byte("ping"))
	}()

	done := make(chan struct{})
	var s2d, d2s int64
	go func() {
		// 200ms idle timeout keeps the test fast.
		s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
		close(done)
	}()

	// Read the forwarded data on the dst side.
	buf := make([]byte, 64)
	n, err := dstClient.Read(buf)
	assert.NoError(t, err)
	assert.Equal(t, "ping", string(buf[:n]))

	// Now stop sending. The relay should close after the idle timeout.
	// Reading s2d/d2s here is race-free: the close(done) above
	// happens-after the Relay call completes.
	select {
	case <-done:
		assert.Greater(t, s2d, int64(0), "should have transferred initial data")
		_ = d2s
	case <-time.After(5 * time.Second):
		t.Fatal("Relay did not exit after idle timeout")
	}
}
|
||||
|
||||
// TestIsExpectedError pins which error values netutil.IsExpectedError
// classifies as benign connection-teardown causes.
func TestIsExpectedError(t *testing.T) {
	assert.True(t, netutil.IsExpectedError(net.ErrClosed))
	assert.True(t, netutil.IsExpectedError(context.Canceled))
	assert.True(t, netutil.IsExpectedError(io.EOF))
	// An unexpected EOF indicates truncated data and is NOT benign.
	assert.False(t, netutil.IsExpectedError(io.ErrUnexpectedEOF))
}
|
||||
570
proxy/internal/tcp/router.go
Normal file
570
proxy/internal/tcp/router.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// defaultDialTimeout is the fallback dial timeout when no per-route
// timeout is configured (Route.DialTimeout == 0).
const defaultDialTimeout = 30 * time.Second

// SNIHost is a typed key for SNI hostname lookups.
type SNIHost string

// RouteType specifies how a connection should be handled.
// Lower values have higher priority during route lookup.
type RouteType int

const (
	// RouteHTTP routes the connection through the HTTP reverse proxy.
	RouteHTTP RouteType = iota
	// RouteTCP relays the connection directly to the backend (TLS passthrough).
	RouteTCP
)

const (
	// sniPeekTimeout is the deadline for reading the TLS ClientHello.
	sniPeekTimeout = 5 * time.Second
	// DefaultDrainTimeout is the default grace period for in-flight relay
	// connections to finish during shutdown.
	DefaultDrainTimeout = 30 * time.Second
	// DefaultMaxRelayConns is the default cap on concurrent TCP relay connections per router.
	DefaultMaxRelayConns = 4096
	// httpChannelBuffer is the capacity of the channel feeding HTTP connections.
	httpChannelBuffer = 4096
)

// DialResolver returns a DialContextFunc for the given account.
type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error)

// Route describes where a connection for a given SNI should be sent.
type Route struct {
	Type      RouteType
	AccountID types.AccountID
	ServiceID types.ServiceID
	// Domain is the service's configured domain, used for access log entries.
	Domain string
	// Protocol is the frontend protocol (tcp, tls), used for access log entries.
	Protocol accesslog.Protocol
	// Target is the backend address for TCP relay (e.g. "10.0.0.5:5432").
	Target string
	// ProxyProtocol enables sending a PROXY protocol v2 header to the backend.
	ProxyProtocol bool
	// DialTimeout overrides the default dial timeout for this route.
	// Zero uses defaultDialTimeout.
	DialTimeout time.Duration
}

// l4Logger sends layer-4 access log entries to the management server.
type l4Logger interface {
	LogL4(entry accesslog.L4Entry)
}

// RelayObserver receives callbacks for TCP relay lifecycle events.
// All methods must be safe for concurrent use.
type RelayObserver interface {
	// TCPRelayStarted fires once a backend connection is established.
	TCPRelayStarted(accountID types.AccountID)
	// TCPRelayEnded fires when a relay finishes, with its duration and
	// the byte counts in each direction.
	TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64)
	// TCPRelayDialError fires when dialing the backend fails.
	TCPRelayDialError(accountID types.AccountID)
	// TCPRelayRejected fires when the relay connection limit is reached.
	TCPRelayRejected(accountID types.AccountID)
}
|
||||
|
||||
// Router accepts raw TCP connections on a shared listener, peeks at
// the TLS ClientHello to extract the SNI, and routes the connection
// to either the HTTP reverse proxy or a direct TCP relay.
type Router struct {
	logger *log.Logger
	// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
	httpCh       chan net.Conn
	httpListener *chanListener
	// mu guards routes, fallback, draining, drainDone, observer,
	// accessLog, svcCtxs and svcCancels.
	mu       sync.RWMutex
	routes   map[SNIHost][]Route
	fallback *Route
	// draining, once set, makes the router reject new relay and HTTP
	// connections (see Drain).
	draining    bool
	dialResolve DialResolver
	// activeConns counts in-flight connection handlers; activeRelays
	// counts established relays. Drain waits on both.
	activeConns  sync.WaitGroup
	activeRelays sync.WaitGroup
	// relaySem bounds concurrent relays (capacity DefaultMaxRelayConns).
	relaySem chan struct{}
	// drainDone is lazily created by the first Drain call and shared by
	// subsequent calls.
	drainDone chan struct{}
	observer  RelayObserver
	accessLog l4Logger
	// svcCtxs tracks a context per service ID. All relay goroutines for a
	// service derive from its context; canceling it kills them immediately.
	svcCtxs    map[types.ServiceID]context.Context
	svcCancels map[types.ServiceID]context.CancelFunc
}
|
||||
|
||||
// NewRouter creates a new SNI-based connection router.
|
||||
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router {
|
||||
httpCh := make(chan net.Conn, httpChannelBuffer)
|
||||
return &Router{
|
||||
logger: logger,
|
||||
httpCh: httpCh,
|
||||
httpListener: newChanListener(httpCh, addr),
|
||||
routes: make(map[SNIHost][]Route),
|
||||
dialResolve: dialResolve,
|
||||
relaySem: make(chan struct{}, DefaultMaxRelayConns),
|
||||
svcCtxs: make(map[types.ServiceID]context.Context),
|
||||
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// NewPortRouter creates a Router for a dedicated port without an HTTP
|
||||
// channel. Connections that don't match any SNI route fall through to
|
||||
// the fallback relay (if set) or are closed.
|
||||
func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router {
|
||||
return &Router{
|
||||
logger: logger,
|
||||
routes: make(map[SNIHost][]Route),
|
||||
dialResolve: dialResolve,
|
||||
relaySem: make(chan struct{}, DefaultMaxRelayConns),
|
||||
svcCtxs: make(map[types.ServiceID]context.Context),
|
||||
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPListener returns a net.Listener that yields connections routed
// to the HTTP handler. Use this with http.Server.ServeTLS.
// It is nil for routers built with NewPortRouter.
func (r *Router) HTTPListener() net.Listener {
	return r.httpListener
}
|
||||
|
||||
// AddRoute registers an SNI route. Multiple routes for the same host are
|
||||
// stored and resolved by priority at lookup time (HTTP > TCP).
|
||||
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
|
||||
func (r *Router) AddRoute(host SNIHost, route Route) {
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
routes := r.routes[host]
|
||||
for i, existing := range routes {
|
||||
if existing.ServiceID == route.ServiceID {
|
||||
r.cancelServiceLocked(route.ServiceID)
|
||||
routes[i] = route
|
||||
return
|
||||
}
|
||||
}
|
||||
r.routes[host] = append(routes, route)
|
||||
}
|
||||
|
||||
// RemoveRoute removes the route for the given host and service ID.
|
||||
// Active relay connections for the service are closed immediately.
|
||||
// If other routes remain for the host, they are preserved.
|
||||
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.routes[host] = slices.DeleteFunc(r.routes[host], func(route Route) bool {
|
||||
return route.ServiceID == svcID
|
||||
})
|
||||
if len(r.routes[host]) == 0 {
|
||||
delete(r.routes, host)
|
||||
}
|
||||
r.cancelServiceLocked(svcID)
|
||||
}
|
||||
|
||||
// SetFallback registers a catch-all route for connections that don't
// match any SNI route. On a port router this handles plain TCP relay;
// on the main router it takes priority over the HTTP channel.
func (r *Router) SetFallback(route Route) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.fallback = &route
}

// RemoveFallback clears the catch-all fallback route and closes any
// active relay connections for the given service.
func (r *Router) RemoveFallback(svcID types.ServiceID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.fallback = nil
	r.cancelServiceLocked(svcID)
}

// SetObserver sets the relay lifecycle observer. Must be called before Serve.
func (r *Router) SetObserver(obs RelayObserver) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.observer = obs
}

// SetAccessLogger sets the L4 access logger. Must be called before Serve.
func (r *Router) SetAccessLogger(l l4Logger) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.accessLog = l
}

// getObserver returns the current relay observer under the read lock.
// It may return nil when no observer was set.
func (r *Router) getObserver() RelayObserver {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.observer
}

// IsEmpty returns true when the router has no SNI routes and no fallback.
func (r *Router) IsEmpty() bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return len(r.routes) == 0 && r.fallback == nil
}
|
||||
|
||||
// Serve accepts connections from ln and routes them based on SNI.
// It blocks until ctx is canceled or ln is closed, then drains
// active relay connections up to DefaultDrainTimeout.
func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
	done := make(chan struct{})
	defer close(done)

	// Watcher: context cancellation closes the listener (which unblocks
	// the Accept below) and the HTTP channel listener. The done channel
	// stops the watcher when Serve returns for any other reason.
	go func() {
		select {
		case <-ctx.Done():
			_ = ln.Close()
			if r.httpListener != nil {
				r.httpListener.Close()
			}
		case <-done:
		}
	}()

	for {
		conn, err := ln.Accept()
		if err != nil {
			// A closed listener or canceled context means shutdown:
			// drain in-flight handlers/relays, then exit cleanly.
			if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				if ok := r.Drain(DefaultDrainTimeout); !ok {
					r.logger.Warn("timed out waiting for connections to drain")
				}
				return nil
			}
			// Other accept errors are treated as transient.
			r.logger.Debugf("SNI router accept: %v", err)
			continue
		}
		r.activeConns.Add(1)
		go func() {
			defer r.activeConns.Done()
			r.handleConn(ctx, conn)
		}()
	}
}
|
||||
|
||||
// handleConn peeks at the TLS ClientHello and routes the connection:
// fallback-only ports skip the peek, HTTP routes go to the HTTP channel,
// TCP routes are relayed to their backend. The connection is always
// either handed off or closed before this returns.
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
	// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
	// fallback port), skip the TLS peek entirely to avoid read errors on
	// non-TLS connections and reduce latency.
	if r.isFallbackOnly() {
		r.handleUnmatched(ctx, conn)
		return
	}

	// Bound the ClientHello read so a silent client can't pin the handler.
	if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil {
		r.logger.Debugf("set SNI peek deadline: %v", err)
		_ = conn.Close()
		return
	}

	sni, wrapped, err := PeekClientHello(conn)
	if err != nil {
		r.logger.Debugf("SNI peek: %v", err)
		// A non-nil wrapped conn still carries the bytes read so far, so
		// it can be forwarded to the fallback; otherwise just close.
		if wrapped != nil {
			r.handleUnmatched(ctx, wrapped)
		} else {
			_ = conn.Close()
		}
		return
	}

	// Clear the peek deadline so the relay/HTTP path owns timeouts.
	if err := wrapped.SetReadDeadline(time.Time{}); err != nil {
		r.logger.Debugf("clear SNI peek deadline: %v", err)
		_ = wrapped.Close()
		return
	}

	host := SNIHost(sni)
	route, ok := r.lookupRoute(host)
	if !ok {
		r.handleUnmatched(ctx, wrapped)
		return
	}

	if route.Type == RouteHTTP {
		r.sendToHTTP(wrapped)
		return
	}

	// relayTCP only closes conn on success; on error we still own it.
	if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
		r.logger.WithFields(log.Fields{
			"sni":        host,
			"service_id": route.ServiceID,
			"target":     route.Target,
		}).Warnf("TCP relay: %v", err)
		_ = wrapped.Close()
	}
}
|
||||
|
||||
// isFallbackOnly returns true when the router has no SNI routes and no HTTP
|
||||
// channel, meaning all connections should go directly to the fallback relay.
|
||||
func (r *Router) isFallbackOnly() bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.routes) == 0 && r.httpCh == nil
|
||||
}
|
||||
|
||||
// handleUnmatched routes a connection that didn't match any SNI route.
// This includes ECH/ESNI connections where the cleartext SNI is empty.
// It tries the fallback relay first, then the HTTP channel, and closes
// the connection if neither is available (sendToHTTP closes when
// httpCh is nil).
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
	r.mu.RLock()
	fb := r.fallback
	r.mu.RUnlock()

	if fb != nil {
		// The synthetic "fallback" SNI only labels log entries.
		if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
			r.logger.WithFields(log.Fields{
				"service_id": fb.ServiceID,
				"target":     fb.Target,
			}).Warnf("TCP relay (fallback): %v", err)
			_ = conn.Close()
		}
		return
	}
	r.sendToHTTP(conn)
}
|
||||
|
||||
// lookupRoute returns the highest-priority route for the given SNI host.
|
||||
// HTTP routes take precedence over TCP routes.
|
||||
func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
routes, ok := r.routes[host]
|
||||
if !ok || len(routes) == 0 {
|
||||
return Route{}, false
|
||||
}
|
||||
best := routes[0]
|
||||
for _, route := range routes[1:] {
|
||||
if route.Type < best.Type {
|
||||
best = route
|
||||
}
|
||||
}
|
||||
return best, true
|
||||
}
|
||||
|
||||
// sendToHTTP feeds the connection to the HTTP handler via the channel.
|
||||
// If no HTTP channel is configured (port router), the router is
|
||||
// draining, or the channel is full, the connection is closed.
|
||||
func (r *Router) sendToHTTP(conn net.Conn) {
|
||||
if r.httpCh == nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
draining := r.draining
|
||||
r.mu.RUnlock()
|
||||
|
||||
if draining {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case r.httpCh <- conn:
|
||||
default:
|
||||
r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Drain prevents new relay connections from starting and waits for all
// in-flight connection handlers and active relays to finish, up to the
// given timeout. Returns true if all completed, false on timeout.
// Safe to call concurrently: the waiter goroutine is created once and
// shared across calls via drainDone.
func (r *Router) Drain(timeout time.Duration) bool {
	r.mu.Lock()
	r.draining = true
	if r.drainDone == nil {
		done := make(chan struct{})
		go func() {
			// Wait for connection handlers first, then any relays
			// they started.
			r.activeConns.Wait()
			r.activeRelays.Wait()
			close(done)
		}()
		r.drainDone = done
	}
	done := r.drainDone
	r.mu.Unlock()

	select {
	case <-done:
		return true
	case <-time.After(timeout):
		return false
	}
}
|
||||
|
||||
// cancelServiceLocked cancels and removes the context for the given service,
|
||||
// closing all its active relay connections. Must be called with mu held.
|
||||
func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
|
||||
if cancel, ok := r.svcCancels[svcID]; ok {
|
||||
cancel()
|
||||
delete(r.svcCtxs, svcID)
|
||||
delete(r.svcCancels, svcID)
|
||||
}
|
||||
}
|
||||
|
||||
// relayTCP sets up and runs a bidirectional TCP relay.
// The caller owns conn and must close it if this method returns an error.
// On success (nil error), both conn and backend are closed by the relay.
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
	// Reserve a relay slot; fails when draining or at the connection cap.
	svcCtx, err := r.acquireRelay(ctx, route)
	if err != nil {
		return err
	}
	defer func() {
		<-r.relaySem
		r.activeRelays.Done()
	}()

	backend, err := r.dialBackend(svcCtx, route)
	if err != nil {
		obs := r.getObserver()
		if obs != nil {
			obs.TCPRelayDialError(route.AccountID)
		}
		return err
	}

	// Optionally announce the original client address to the backend
	// before any payload bytes flow.
	if route.ProxyProtocol {
		if err := writeProxyProtoV2(conn, backend); err != nil {
			_ = backend.Close()
			return fmt.Errorf("write PROXY protocol header: %w", err)
		}
	}

	obs := r.getObserver()
	if obs != nil {
		obs.TCPRelayStarted(route.AccountID)
	}

	entry := r.logger.WithFields(log.Fields{
		"sni":        sni,
		"service_id": route.ServiceID,
		"target":     route.Target,
	})
	entry.Debug("TCP relay started")

	start := time.Now()
	// Relay blocks until both directions finish or svcCtx is canceled
	// (route removal / shutdown); it closes conn and backend itself.
	s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
	elapsed := time.Since(start)

	if obs != nil {
		obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s)
	}
	entry.Debugf("TCP relay ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s)

	r.logL4Entry(route, conn, elapsed, s2d, d2s)
	return nil
}
|
||||
|
||||
// acquireRelay checks draining state, increments activeRelays, and acquires
// a semaphore slot. Returns the per-service context on success.
// The caller must release the semaphore and call activeRelays.Done() when done.
func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) {
	r.mu.Lock()
	if r.draining {
		r.mu.Unlock()
		return nil, errors.New("router is draining")
	}
	// activeRelays is incremented under mu so Drain cannot observe a
	// zero count between the draining check and the Add.
	r.activeRelays.Add(1)
	svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID)
	r.mu.Unlock()

	// Non-blocking semaphore acquire: at capacity we reject rather than
	// queue, so the client fails fast.
	select {
	case r.relaySem <- struct{}{}:
		return svcCtx, nil
	default:
		r.activeRelays.Done()
		obs := r.getObserver()
		if obs != nil {
			obs.TCPRelayRejected(route.AccountID)
		}
		return nil, errors.New("TCP relay connection limit reached")
	}
}
|
||||
|
||||
// dialBackend resolves the dialer for the route's account and dials the backend.
|
||||
func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) {
|
||||
dialFn, err := r.dialResolve(route.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve dialer: %w", err)
|
||||
}
|
||||
|
||||
dialTimeout := route.DialTimeout
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = defaultDialTimeout
|
||||
}
|
||||
dialCtx, dialCancel := context.WithTimeout(svcCtx, dialTimeout)
|
||||
backend, err := dialFn(dialCtx, "tcp", route.Target)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial backend %s: %w", route.Target, err)
|
||||
}
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// logL4Entry sends a TCP relay access log entry if an access logger is configured.
|
||||
func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) {
|
||||
r.mu.RLock()
|
||||
al := r.accessLog
|
||||
r.mu.RUnlock()
|
||||
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var sourceIP netip.Addr
|
||||
if remote := conn.RemoteAddr(); remote != nil {
|
||||
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
|
||||
sourceIP = ap.Addr().Unmap()
|
||||
}
|
||||
}
|
||||
|
||||
al.LogL4(accesslog.L4Entry{
|
||||
AccountID: route.AccountID,
|
||||
ServiceID: route.ServiceID,
|
||||
Protocol: route.Protocol,
|
||||
Host: route.Domain,
|
||||
SourceIP: sourceIP,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
BytesUpload: bytesUp,
|
||||
BytesDownload: bytesDown,
|
||||
})
|
||||
}
|
||||
|
||||
// getOrCreateServiceCtxLocked returns the context for a service, creating one
|
||||
// if it doesn't exist yet. The context is a child of the server context.
|
||||
// Must be called with mu held.
|
||||
func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context {
|
||||
if ctx, ok := r.svcCtxs[svcID]; ok {
|
||||
return ctx
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parent)
|
||||
r.svcCtxs[svcID] = ctx
|
||||
r.svcCancels[svcID] = cancel
|
||||
return ctx
|
||||
}
|
||||
1670
proxy/internal/tcp/router_test.go
Normal file
1670
proxy/internal/tcp/router_test.go
Normal file
File diff suppressed because it is too large
Load Diff
191
proxy/internal/tcp/snipeek.go
Normal file
191
proxy/internal/tcp/snipeek.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Protocol constants for minimal TLS ClientHello parsing (RFC 8446 record
// layer, RFC 6066 server_name extension).
const (
	// TLS record header is 5 bytes: ContentType(1) + Version(2) + Length(2).
	tlsRecordHeaderLen = 5
	// TLS handshake type for ClientHello.
	handshakeTypeClientHello = 1
	// TLS ContentType for handshake messages.
	contentTypeHandshake = 22
	// SNI extension type (RFC 6066).
	extensionServerName = 0
	// SNI host name type (host_name in the ServerNameList).
	sniHostNameType = 0
	// maxClientHelloLen caps the ClientHello size we're willing to buffer.
	maxClientHelloLen = 16384
	// maxSNILen is the maximum valid DNS hostname length per RFC 1035.
	maxSNILen = 253
)
|
||||
|
||||
// PeekClientHello reads the TLS ClientHello from conn, extracts the SNI
|
||||
// server name, and returns a wrapped connection that replays the peeked
|
||||
// bytes transparently. If the data is not a valid TLS ClientHello or
|
||||
// contains no SNI extension, sni is empty and err is nil.
|
||||
//
|
||||
// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the
|
||||
// real server name is encrypted inside the encrypted_client_hello
|
||||
// extension. This parser only reads the cleartext server_name extension
|
||||
// (type 0x0000), so ECH connections return sni="" and are routed through
|
||||
// the fallback path (or HTTP channel), which is the correct behavior
|
||||
// for a transparent proxy that does not terminate TLS.
|
||||
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
|
||||
// Read the 5-byte TLS record header into a small stack-friendly buffer.
|
||||
var header [tlsRecordHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, header[:]); err != nil {
|
||||
return "", nil, fmt.Errorf("read TLS record header: %w", err)
|
||||
}
|
||||
|
||||
if header[0] != contentTypeHandshake {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
if recordLen == 0 || recordLen > maxClientHelloLen {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
}
|
||||
|
||||
// Single allocation for header + payload. The peekedConn takes
|
||||
// ownership of this buffer, so no further copies are needed.
|
||||
buf := make([]byte, tlsRecordHeaderLen+recordLen)
|
||||
copy(buf, header[:])
|
||||
|
||||
n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:])
|
||||
if err != nil {
|
||||
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err)
|
||||
}
|
||||
|
||||
sni = extractSNI(buf[tlsRecordHeaderLen:])
|
||||
return sni, newPeekedConn(conn, buf), nil
|
||||
}
|
||||
|
||||
// extractSNI parses a TLS handshake payload to find the SNI extension.
|
||||
// Returns empty string if the payload is not a ClientHello or has no SNI.
|
||||
func extractSNI(payload []byte) string {
|
||||
if len(payload) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if payload[0] != handshakeTypeClientHello {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes, big-endian).
|
||||
handshakeLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3])
|
||||
if handshakeLen > len(payload)-4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return parseSNIFromClientHello(payload[4 : 4+handshakeLen])
|
||||
}
|
||||
|
||||
// parseSNIFromClientHello walks the ClientHello message fields to reach
|
||||
// the extensions block and extract the server_name extension value.
|
||||
func parseSNIFromClientHello(msg []byte) string {
|
||||
// ClientHello layout:
|
||||
// ProtocolVersion(2) + Random(32) = 34 bytes minimum before session_id
|
||||
if len(msg) < 34 {
|
||||
return ""
|
||||
}
|
||||
|
||||
pos := 34
|
||||
|
||||
// Session ID (variable, 1 byte length prefix).
|
||||
if pos >= len(msg) {
|
||||
return ""
|
||||
}
|
||||
sessionIDLen := int(msg[pos])
|
||||
pos++
|
||||
pos += sessionIDLen
|
||||
|
||||
// Cipher suites (variable, 2 byte length prefix).
|
||||
if pos+2 > len(msg) {
|
||||
return ""
|
||||
}
|
||||
cipherSuitesLen := int(binary.BigEndian.Uint16(msg[pos : pos+2]))
|
||||
pos += 2 + cipherSuitesLen
|
||||
|
||||
// Compression methods (variable, 1 byte length prefix).
|
||||
if pos >= len(msg) {
|
||||
return ""
|
||||
}
|
||||
compMethodsLen := int(msg[pos])
|
||||
pos++
|
||||
pos += compMethodsLen
|
||||
|
||||
// Extensions (variable, 2 byte length prefix).
|
||||
if pos+2 > len(msg) {
|
||||
return ""
|
||||
}
|
||||
extensionsLen := int(binary.BigEndian.Uint16(msg[pos : pos+2]))
|
||||
pos += 2
|
||||
|
||||
extensionsEnd := pos + extensionsLen
|
||||
if extensionsEnd > len(msg) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return findSNIExtension(msg[pos:extensionsEnd])
|
||||
}
|
||||
|
||||
// findSNIExtension iterates over TLS extensions and returns the host
|
||||
// name from the server_name extension, if present.
|
||||
func findSNIExtension(extensions []byte) string {
|
||||
pos := 0
|
||||
for pos+4 <= len(extensions) {
|
||||
extType := binary.BigEndian.Uint16(extensions[pos : pos+2])
|
||||
extLen := int(binary.BigEndian.Uint16(extensions[pos+2 : pos+4]))
|
||||
pos += 4
|
||||
|
||||
if pos+extLen > len(extensions) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if extType == extensionServerName {
|
||||
return parseSNIExtensionData(extensions[pos : pos+extLen])
|
||||
}
|
||||
pos += extLen
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseSNIExtensionData parses the ServerNameList structure inside an
|
||||
// SNI extension to extract the host name.
|
||||
func parseSNIExtensionData(data []byte) string {
|
||||
if len(data) < 2 {
|
||||
return ""
|
||||
}
|
||||
listLen := int(binary.BigEndian.Uint16(data[0:2]))
|
||||
if listLen > len(data)-2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
list := data[2 : 2+listLen]
|
||||
pos := 0
|
||||
for pos+3 <= len(list) {
|
||||
nameType := list[pos]
|
||||
nameLen := int(binary.BigEndian.Uint16(list[pos+1 : pos+3]))
|
||||
pos += 3
|
||||
|
||||
if pos+nameLen > len(list) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if nameType == sniHostNameType {
|
||||
name := list[pos : pos+nameLen]
|
||||
if nameLen > maxSNILen || bytes.ContainsRune(name, 0) {
|
||||
return ""
|
||||
}
|
||||
return string(name)
|
||||
}
|
||||
pos += nameLen
|
||||
}
|
||||
return ""
|
||||
}
|
||||
251
proxy/internal/tcp/snipeek_test.go
Normal file
251
proxy/internal/tcp/snipeek_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPeekClientHello_ValidSNI verifies that a real ClientHello produced by
// crypto/tls yields the expected SNI, and that the wrapped connection first
// replays the peeked bytes and then passes through subsequent client data.
func TestPeekClientHello_ValidSNI(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()

	const expectedSNI = "example.com"
	trailingData := []byte("trailing data after handshake")

	go func() {
		tlsConn := tls.Client(clientConn, &tls.Config{
			ServerName:         expectedSNI,
			InsecureSkipVerify: true, //nolint:gosec
		})
		// The Handshake will send the ClientHello. It will fail because
		// our server side isn't doing a real TLS handshake, but that's
		// fine: we only need the ClientHello to be sent.
		_ = tlsConn.Handshake()
	}()

	sni, wrapped, err := PeekClientHello(serverConn)
	require.NoError(t, err)
	assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello")
	assert.NotNil(t, wrapped, "wrapped connection should not be nil")

	// Verify the wrapped connection replays the peeked bytes.
	// Read the first 5 bytes (TLS record header) to confirm replay.
	buf := make([]byte, 5)
	n, err := wrapped.Read(buf)
	require.NoError(t, err)
	assert.Equal(t, 5, n)
	assert.Equal(t, byte(contentTypeHandshake), buf[0], "first byte should be TLS handshake content type")

	// Write trailing data from the client side and verify it arrives
	// through the wrapped connection after the peeked bytes.
	go func() {
		_, _ = clientConn.Write(trailingData)
	}()

	// Drain the rest of the peeked ClientHello first.
	peekedRest := make([]byte, 16384)
	_, _ = wrapped.Read(peekedRest)

	got := make([]byte, len(trailingData))
	n, err = io.ReadFull(wrapped, got)
	require.NoError(t, err)
	assert.Equal(t, trailingData, got[:n])
}
|
||||
|
||||
// TestPeekClientHello_MultipleSNIs checks SNI extraction across a range of
// hostname shapes (bare domain, subdomain, deep subdomain).
func TestPeekClientHello_MultipleSNIs(t *testing.T) {
	tests := []struct {
		name        string
		serverName  string
		expectedSNI string
	}{
		{"simple domain", "example.com", "example.com"},
		{"subdomain", "sub.example.com", "sub.example.com"},
		{"deep subdomain", "a.b.c.example.com", "a.b.c.example.com"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			clientConn, serverConn := net.Pipe()
			defer clientConn.Close()
			defer serverConn.Close()

			go func() {
				// Only the ClientHello matters; the handshake itself may fail.
				tlsConn := tls.Client(clientConn, &tls.Config{
					ServerName:         tt.serverName,
					InsecureSkipVerify: true, //nolint:gosec
				})
				_ = tlsConn.Handshake()
			}()

			sni, wrapped, err := PeekClientHello(serverConn)
			require.NoError(t, err)
			assert.Equal(t, tt.expectedSNI, sni)
			assert.NotNil(t, wrapped)
		})
	}
}
|
||||
|
||||
// TestPeekClientHello_NonTLSData ensures plain (non-TLS) traffic yields an
// empty SNI without error, and is still replayed unchanged by the wrapper.
func TestPeekClientHello_NonTLSData(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()

	// Send plain HTTP data (not TLS).
	httpData := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
	go func() {
		_, _ = clientConn.Write(httpData)
	}()

	sni, wrapped, err := PeekClientHello(serverConn)
	require.NoError(t, err)
	assert.Empty(t, sni, "should return empty SNI for non-TLS data")
	assert.NotNil(t, wrapped)

	// Verify the wrapped connection still provides the original data.
	buf := make([]byte, len(httpData))
	n, err := io.ReadFull(wrapped, buf)
	require.NoError(t, err)
	assert.Equal(t, httpData, buf[:n], "wrapped connection should replay original data")
}
|
||||
|
||||
// TestPeekClientHello_TruncatedHeader verifies an error is returned when the
// connection closes before a full 5-byte TLS record header arrives.
func TestPeekClientHello_TruncatedHeader(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer serverConn.Close()

	// Write only 3 bytes then close, fewer than the 5-byte TLS header.
	go func() {
		_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01})
		clientConn.Close()
	}()

	_, _, err := PeekClientHello(serverConn)
	assert.Error(t, err, "should error on truncated header")
}
|
||||
|
||||
// TestPeekClientHello_TruncatedPayload verifies an error is returned when the
// record header claims more payload bytes than the peer actually sends.
func TestPeekClientHello_TruncatedPayload(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer serverConn.Close()

	// Write a valid TLS header claiming 100 bytes, but only send 10.
	go func() {
		header := []byte{0x16, 0x03, 0x01, 0x00, 0x64} // 100 bytes claimed
		_, _ = clientConn.Write(header)
		_, _ = clientConn.Write(make([]byte, 10))
		clientConn.Close()
	}()

	_, _, err := PeekClientHello(serverConn)
	assert.Error(t, err, "should error on truncated payload")
}
|
||||
|
||||
// TestPeekClientHello_ZeroLengthRecord verifies a handshake record with a
// zero-length payload is tolerated: no error, empty SNI, replayable wrapper.
func TestPeekClientHello_ZeroLengthRecord(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()

	// TLS handshake header with zero-length payload.
	go func() {
		_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00})
	}()

	sni, wrapped, err := PeekClientHello(serverConn)
	require.NoError(t, err)
	assert.Empty(t, sni)
	assert.NotNil(t, wrapped)
}
|
||||
|
||||
// TestExtractSNI_InvalidPayload exercises extractSNI with malformed handshake
// payloads; all of them must yield an empty SNI without panicking.
func TestExtractSNI_InvalidPayload(t *testing.T) {
	tests := []struct {
		name    string
		payload []byte
	}{
		{"nil", nil},
		{"empty", []byte{}},
		{"too short", []byte{0x01, 0x00}},
		{"wrong handshake type", []byte{0x02, 0x00, 0x00, 0x05, 0x03, 0x03, 0x00, 0x00, 0x00}},
		{"truncated client hello", []byte{0x01, 0x00, 0x00, 0x20}}, // claims 32 bytes but has none
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Empty(t, extractSNI(tt.payload))
		})
	}
}
|
||||
|
||||
// TestPeekedConn_CloseWrite verifies half-close behavior of the wrapper:
// CloseWrite delegates to the underlying *net.TCPConn when available and is a
// harmless no-op on connections without half-close support.
func TestPeekedConn_CloseWrite(t *testing.T) {
	t.Run("delegates to underlying TCPConn", func(t *testing.T) {
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		defer ln.Close()

		accepted := make(chan net.Conn, 1)
		go func() {
			c, err := ln.Accept()
			if err == nil {
				accepted <- c
			}
		}()

		client, err := net.Dial("tcp", ln.Addr().String())
		require.NoError(t, err)
		defer client.Close()

		server := <-accepted
		defer server.Close()

		wrapped := newPeekedConn(server, []byte("peeked"))

		// CloseWrite should succeed on a real TCP connection.
		err = wrapped.CloseWrite()
		assert.NoError(t, err)

		// The client should see EOF on reads after CloseWrite.
		buf := make([]byte, 1)
		_, err = client.Read(buf)
		assert.Equal(t, io.EOF, err, "client should see EOF after half-close")
	})

	t.Run("no-op on non-halfcloser", func(t *testing.T) {
		// net.Pipe does not implement CloseWrite.
		_, server := net.Pipe()
		defer server.Close()

		wrapped := newPeekedConn(server, []byte("peeked"))
		err := wrapped.CloseWrite()
		assert.NoError(t, err, "should be no-op on non-halfcloser")
	})
}
|
||||
|
||||
// TestPeekedConn_ReplayAndPassthrough verifies read ordering on the wrapper:
// buffered (peeked) bytes are returned first, then reads fall through to the
// underlying connection.
func TestPeekedConn_ReplayAndPassthrough(t *testing.T) {
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()
	defer serverConn.Close()

	peeked := []byte("peeked-data")
	subsequent := []byte("subsequent-data")

	wrapped := newPeekedConn(serverConn, peeked)

	go func() {
		_, _ = clientConn.Write(subsequent)
	}()

	// Read should return peeked data first.
	buf := make([]byte, len(peeked))
	n, err := io.ReadFull(wrapped, buf)
	require.NoError(t, err)
	assert.Equal(t, peeked, buf[:n])

	// Then subsequent data from the real connection.
	buf = make([]byte, len(subsequent))
	n, err = io.ReadFull(wrapped, buf)
	require.NoError(t, err)
	assert.Equal(t, subsequent, buf[:n])
}
|
||||
@@ -1,5 +1,56 @@
|
||||
// Package types defines common types used across the proxy package.
package types

import (
	"context"
	"net"
	"time"
)

// AccountID represents a unique identifier for a NetBird account.
type AccountID string

// ServiceID represents a unique identifier for a proxy service.
type ServiceID string
|
||||
|
||||
// ServiceMode describes how a reverse proxy service is exposed.
type ServiceMode string

// Supported service exposure modes.
const (
	ServiceModeHTTP ServiceMode = "http"
	ServiceModeTCP  ServiceMode = "tcp"
	ServiceModeUDP  ServiceMode = "udp"
	ServiceModeTLS  ServiceMode = "tls"
)

// IsL4 returns true for TCP, UDP, and TLS modes.
func (m ServiceMode) IsL4() bool {
	switch m {
	case ServiceModeTCP, ServiceModeUDP, ServiceModeTLS:
		return true
	default:
		return false
	}
}
|
||||
|
||||
// RelayDirection indicates the direction of a relayed packet.
type RelayDirection string

// Relay directions reported to observers and metrics.
const (
	RelayDirectionClientToBackend RelayDirection = "client_to_backend"
	RelayDirectionBackendToClient RelayDirection = "backend_to_client"
)
|
||||
|
||||
// DialContextFunc dials a backend through the WireGuard tunnel.
|
||||
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
// dialTimeoutKey is the context key for a per-request dial timeout.
|
||||
type dialTimeoutKey struct{}
|
||||
|
||||
// WithDialTimeout returns a context carrying a dial timeout that
|
||||
// DialContext wrappers can use to scope the timeout to just the
|
||||
// connection establishment phase.
|
||||
func WithDialTimeout(ctx context.Context, d time.Duration) context.Context {
|
||||
return context.WithValue(ctx, dialTimeoutKey{}, d)
|
||||
}
|
||||
|
||||
// DialTimeoutFromContext returns the dial timeout from the context, if set.
|
||||
func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) {
|
||||
d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration)
|
||||
return d, ok && d > 0
|
||||
}
|
||||
|
||||
54
proxy/internal/types/types_test.go
Normal file
54
proxy/internal/types/types_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestServiceMode_IsL4 table-tests the L4 classification of each service
// mode, including an unknown mode.
func TestServiceMode_IsL4(t *testing.T) {
	tests := []struct {
		mode ServiceMode
		want bool
	}{
		{ServiceModeHTTP, false},
		{ServiceModeTCP, true},
		{ServiceModeUDP, true},
		{ServiceModeTLS, true},
		{ServiceMode("unknown"), false},
	}

	for _, tt := range tests {
		t.Run(string(tt.mode), func(t *testing.T) {
			assert.Equal(t, tt.want, tt.mode.IsL4())
		})
	}
}
|
||||
|
||||
// TestDialTimeoutContext covers the context round trip for dial timeouts:
// present, missing, zero, and negative values.
func TestDialTimeoutContext(t *testing.T) {
	t.Run("round trip", func(t *testing.T) {
		ctx := WithDialTimeout(context.Background(), 5*time.Second)
		d, ok := DialTimeoutFromContext(ctx)
		assert.True(t, ok)
		assert.Equal(t, 5*time.Second, d)
	})

	t.Run("missing", func(t *testing.T) {
		_, ok := DialTimeoutFromContext(context.Background())
		assert.False(t, ok)
	})

	t.Run("zero returns false", func(t *testing.T) {
		ctx := WithDialTimeout(context.Background(), 0)
		_, ok := DialTimeoutFromContext(ctx)
		assert.False(t, ok, "zero duration should return ok=false")
	})

	t.Run("negative returns false", func(t *testing.T) {
		ctx := WithDialTimeout(context.Background(), -1*time.Second)
		_, ok := DialTimeoutFromContext(ctx)
		assert.False(t, ok, "negative duration should return ok=false")
	})
}
|
||||
496
proxy/internal/udp/relay.go
Normal file
496
proxy/internal/udp/relay.go
Normal file
@@ -0,0 +1,496 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// Tunables for the UDP relay. Session limits and rate limits protect the
// proxy from unbounded session creation by untrusted clients.
const (
	// DefaultSessionTTL is the default idle timeout for UDP sessions before cleanup.
	DefaultSessionTTL = 30 * time.Second
	// cleanupInterval is how often the cleaner goroutine runs.
	cleanupInterval = time.Minute
	// maxPacketSize is the maximum UDP packet size we'll handle (64 KiB - 1,
	// the largest possible UDP payload).
	maxPacketSize = 65535
	// DefaultMaxSessions is the default cap on concurrent UDP sessions per relay.
	DefaultMaxSessions = 1024
	// sessionCreateRate limits new session creation per second.
	sessionCreateRate = 50
	// sessionCreateBurst is the burst allowance for session creation.
	sessionCreateBurst = 100
	// defaultDialTimeout is the fallback dial timeout for backend connections.
	defaultDialTimeout = 30 * time.Second
)
|
||||
|
||||
// l4Logger sends layer-4 access log entries to the management server.
// Defined here (at the consumer) so the relay depends only on what it uses.
type l4Logger interface {
	LogL4(entry accesslog.L4Entry)
}
|
||||
|
||||
// SessionObserver receives callbacks for UDP session lifecycle events.
// All methods must be safe for concurrent use.
type SessionObserver interface {
	// UDPSessionStarted is called after a backend connection is established.
	UDPSessionStarted(accountID types.AccountID)
	// UDPSessionEnded is called when a session is reaped or torn down.
	UDPSessionEnded(accountID types.AccountID)
	// UDPSessionDialError is called when dialing the backend fails.
	UDPSessionDialError(accountID types.AccountID)
	// UDPSessionRejected is called when a session is refused by the
	// session cap or the creation rate limiter.
	UDPSessionRejected(accountID types.AccountID)
	// UDPPacketRelayed is called once per relayed packet in either direction.
	UDPPacketRelayed(direction types.RelayDirection, bytes int)
}
|
||||
|
||||
// clientAddr is a typed key for UDP session lookups (the client's
// address rendered as a string).
type clientAddr string
|
||||
|
||||
// Relay listens for incoming UDP packets on a dedicated port and
// maintains per-client sessions that relay packets to a backend
// through the WireGuard tunnel.
type Relay struct {
	logger      *log.Entry
	listener    net.PacketConn
	target      string
	domain      string
	accountID   types.AccountID
	serviceID   types.ServiceID
	dialFunc    types.DialContextFunc
	dialTimeout time.Duration
	sessionTTL  time.Duration
	maxSessions int

	// mu guards sessions.
	mu       sync.RWMutex
	sessions map[clientAddr]*session

	// bufPool recycles maxPacketSize read buffers across the serve loop
	// and the per-session backend readers.
	bufPool sync.Pool
	// sessLimiter rate-limits new session creation.
	sessLimiter *rate.Limiter
	// sessWg tracks per-session backend-reader goroutines.
	sessWg sync.WaitGroup
	ctx    context.Context
	cancel context.CancelFunc
	// observer must be set via SetObserver before Serve; not mutex-guarded.
	observer  SessionObserver
	accessLog l4Logger
}
|
||||
|
||||
type session struct {
|
||||
backend net.Conn
|
||||
addr net.Addr
|
||||
createdAt time.Time
|
||||
// lastSeen stores the last activity timestamp as unix nanoseconds.
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
// bytesIn tracks total bytes received from the client.
|
||||
bytesIn atomic.Int64
|
||||
// bytesOut tracks total bytes sent back to the client.
|
||||
bytesOut atomic.Int64
|
||||
}
|
||||
|
||||
func (s *session) updateLastSeen() {
|
||||
s.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (s *session) idleDuration() time.Duration {
|
||||
return time.Since(time.Unix(0, s.lastSeen.Load()))
|
||||
}
|
||||
|
||||
// RelayConfig holds the configuration for a UDP relay.
type RelayConfig struct {
	Logger   *log.Entry
	Listener net.PacketConn
	// Target is the backend address packets are relayed to.
	Target string
	// Domain is the public host name, used for access logging.
	Domain    string
	AccountID types.AccountID
	ServiceID types.ServiceID
	// DialFunc dials the backend through the WireGuard tunnel.
	DialFunc types.DialContextFunc
	// DialTimeout bounds backend connection establishment; 0 uses the default.
	DialTimeout time.Duration
	// SessionTTL is the idle timeout before a session is reaped; 0 uses the default.
	SessionTTL time.Duration
	// MaxSessions caps concurrent sessions; 0 uses DefaultMaxSessions.
	MaxSessions int
	AccessLog   l4Logger
}
|
||||
|
||||
// New creates a UDP relay for the given listener and backend target.
|
||||
// MaxSessions caps the number of concurrent sessions; use 0 for DefaultMaxSessions.
|
||||
// DialTimeout controls how long to wait for backend connections; use 0 for default.
|
||||
// SessionTTL is the idle timeout before a session is reaped; use 0 for DefaultSessionTTL.
|
||||
func New(parentCtx context.Context, cfg RelayConfig) *Relay {
|
||||
maxSessions := cfg.MaxSessions
|
||||
dialTimeout := cfg.DialTimeout
|
||||
sessionTTL := cfg.SessionTTL
|
||||
if maxSessions <= 0 {
|
||||
maxSessions = DefaultMaxSessions
|
||||
}
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = defaultDialTimeout
|
||||
}
|
||||
if sessionTTL <= 0 {
|
||||
sessionTTL = DefaultSessionTTL
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
return &Relay{
|
||||
logger: cfg.Logger,
|
||||
listener: cfg.Listener,
|
||||
target: cfg.Target,
|
||||
domain: cfg.Domain,
|
||||
accountID: cfg.AccountID,
|
||||
serviceID: cfg.ServiceID,
|
||||
accessLog: cfg.AccessLog,
|
||||
dialFunc: cfg.DialFunc,
|
||||
dialTimeout: dialTimeout,
|
||||
sessionTTL: sessionTTL,
|
||||
maxSessions: maxSessions,
|
||||
sessions: make(map[clientAddr]*session),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, maxPacketSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
sessLimiter: rate.NewLimiter(sessionCreateRate, sessionCreateBurst),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceID returns the service ID associated with this relay.
func (r *Relay) ServiceID() types.ServiceID {
	return r.serviceID
}
|
||||
|
||||
// SetObserver sets the session lifecycle observer. Must be called before
// Serve: the field is read without synchronization by the serve loop.
func (r *Relay) SetObserver(obs SessionObserver) {
	r.observer = obs
}
|
||||
|
||||
// Serve starts the relay loop. It blocks until the context is canceled
// or the listener is closed. For each datagram it finds or creates the
// client's session and forwards the payload to the backend.
func (r *Relay) Serve() {
	// Background reaper for idle sessions; stops with r.ctx.
	go r.cleanupLoop()

	for {
		// Borrow a full-size packet buffer from the pool for this read.
		bufp := r.bufPool.Get().(*[]byte)
		buf := *bufp

		n, addr, err := r.listener.ReadFrom(buf)
		if err != nil {
			r.bufPool.Put(bufp)
			// Context cancellation or a closed listener means shutdown.
			if r.ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			r.logger.Debugf("UDP read: %v", err)
			continue
		}

		data := buf[:n]
		sess, err := r.getOrCreateSession(addr)
		if err != nil {
			// Session refused (limit, rate limit, dial failure, shutdown).
			r.bufPool.Put(bufp)
			r.logger.Debugf("create UDP session for %s: %v", addr, err)
			continue
		}

		sess.updateLastSeen()

		nw, err := sess.backend.Write(data)
		if err != nil {
			r.bufPool.Put(bufp)
			if !netutil.IsExpectedError(err) {
				r.logger.Debugf("UDP write to backend for %s: %v", addr, err)
			}
			// A failed backend write invalidates the whole session.
			r.removeSession(sess)
			continue
		}
		sess.bytesIn.Add(int64(nw))

		if r.observer != nil {
			r.observer.UDPPacketRelayed(types.RelayDirectionClientToBackend, nw)
		}
		// Return the buffer only after the payload was fully written.
		r.bufPool.Put(bufp)
	}
}
|
||||
|
||||
// getOrCreateSession returns an existing session or creates a new one.
//
// Creation protocol: the key is first reserved in the map with a nil
// session (under the write lock) so concurrent callers for the same key
// back off while this goroutine dials the backend outside the lock.
// Session cap and rate-limit checks happen before the reservation.
func (r *Relay) getOrCreateSession(addr net.Addr) (*session, error) {
	key := clientAddr(addr.String())

	// Fast path: read-locked lookup of an established session.
	r.mu.RLock()
	sess, ok := r.sessions[key]
	r.mu.RUnlock()
	if ok && sess != nil {
		return sess, nil
	}

	// Check before taking the write lock: if the relay is shutting down,
	// don't create new sessions. This prevents orphaned goroutines when
	// Serve() processes a packet that was already read before Close().
	if r.ctx.Err() != nil {
		return nil, r.ctx.Err()
	}

	r.mu.Lock()

	// Re-check under the write lock: another goroutine may have finished.
	if sess, ok = r.sessions[key]; ok && sess != nil {
		r.mu.Unlock()
		return sess, nil
	}
	if ok {
		// Another goroutine is dialing for this key, skip.
		r.mu.Unlock()
		return nil, fmt.Errorf("session dial in progress for %s", key)
	}

	// Enforce the concurrent-session cap.
	if len(r.sessions) >= r.maxSessions {
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionRejected(r.accountID)
		}
		return nil, fmt.Errorf("session limit reached (%d)", r.maxSessions)
	}

	// Enforce the session-creation rate limit.
	if !r.sessLimiter.Allow() {
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionRejected(r.accountID)
		}
		return nil, fmt.Errorf("session creation rate limited")
	}

	// Reserve the slot with a nil session so concurrent callers for the same
	// key see it exists and wait. Release the lock before dialing.
	r.sessions[key] = nil
	r.mu.Unlock()

	dialCtx, dialCancel := context.WithTimeout(r.ctx, r.dialTimeout)
	backend, err := r.dialFunc(dialCtx, "udp", r.target)
	dialCancel()
	if err != nil {
		// Dial failed: release the reservation so a later packet retries.
		r.mu.Lock()
		delete(r.sessions, key)
		r.mu.Unlock()
		if r.observer != nil {
			r.observer.UDPSessionDialError(r.accountID)
		}
		return nil, fmt.Errorf("dial backend %s: %w", r.target, err)
	}

	sessCtx, sessCancel := context.WithCancel(r.ctx)
	sess = &session{
		backend:   backend,
		addr:      addr,
		createdAt: time.Now(),
		cancel:    sessCancel,
	}
	sess.updateLastSeen()

	// Publish the fully-constructed session, replacing the nil reservation.
	r.mu.Lock()
	r.sessions[key] = sess
	r.mu.Unlock()

	if r.observer != nil {
		r.observer.UDPSessionStarted(r.accountID)
	}

	// Backend-to-client pump; tracked so Close can wait for it.
	r.sessWg.Go(func() {
		r.relayBackendToClient(sessCtx, sess)
	})

	r.logger.Debugf("UDP session created for %s", addr)
	return sess, nil
}
|
||||
|
||||
// relayBackendToClient reads packets from the backend and writes them
// back to the client through the public-facing listener. It runs as one
// goroutine per session and removes the session when it returns.
func (r *Relay) relayBackendToClient(ctx context.Context, sess *session) {
	// Reuse one pooled buffer for the lifetime of this session's reader.
	bufp := r.bufPool.Get().(*[]byte)
	defer r.bufPool.Put(bufp)
	defer r.removeSession(sess)

	for ctx.Err() == nil {
		data, ok := r.readBackendPacket(sess, *bufp)
		if !ok {
			// Fatal read error or idle expiry: tear the session down.
			return
		}
		if data == nil {
			// Idle-deadline tick with recent activity: keep waiting.
			continue
		}

		sess.updateLastSeen()

		nw, err := r.listener.WriteTo(data, sess.addr)
		if err != nil {
			if !netutil.IsExpectedError(err) {
				r.logger.Debugf("UDP write to client %s: %v", sess.addr, err)
			}
			return
		}
		sess.bytesOut.Add(int64(nw))

		if r.observer != nil {
			r.observer.UDPPacketRelayed(types.RelayDirectionBackendToClient, nw)
		}
	}
}
|
||||
|
||||
// readBackendPacket reads one packet from the backend with an idle deadline.
|
||||
// Returns (data, true) on success, (nil, true) on idle timeout that should
|
||||
// retry, or (nil, false) when the session should be torn down.
|
||||
func (r *Relay) readBackendPacket(sess *session, buf []byte) ([]byte, bool) {
|
||||
if err := sess.backend.SetReadDeadline(time.Now().Add(r.sessionTTL)); err != nil {
|
||||
r.logger.Debugf("set backend read deadline for %s: %v", sess.addr, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
n, err := sess.backend.Read(buf)
|
||||
if err != nil {
|
||||
if netutil.IsTimeout(err) {
|
||||
if sess.idleDuration() > r.sessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
if !netutil.IsExpectedError(err) {
|
||||
r.logger.Debugf("UDP read from backend for %s: %v", sess.addr, err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return buf[:n], true
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes idle sessions.
|
||||
func (r *Relay) cleanupLoop() {
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.cleanupIdleSessions()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupIdleSessions closes sessions that have been idle for too long.
//
// Expired sessions are removed from the map and their backend connections
// closed while holding r.mu; observer and access-log notifications are
// deferred until after the lock is released to keep the critical section
// short. Closing the backend unblocks the session's relay goroutine, whose
// subsequent removeSession call sees the map entry gone (identity check)
// and therefore does not double-close or double-report.
func (r *Relay) cleanupIdleSessions() {
	var expired []*session

	r.mu.Lock()
	for key, sess := range r.sessions {
		if sess == nil {
			// Defensive: never dereference a nil map entry.
			continue
		}
		idle := sess.idleDuration()
		if idle > r.sessionTTL {
			r.logger.Debugf("UDP session %s idle for %s, closing (client→backend: %d bytes, backend→client: %d bytes)",
				sess.addr, idle, sess.bytesIn.Load(), sess.bytesOut.Load())
			delete(r.sessions, key)
			sess.cancel()
			if err := sess.backend.Close(); err != nil {
				r.logger.Debugf("close idle session %s backend: %v", sess.addr, err)
			}
			expired = append(expired, sess)
		}
	}
	r.mu.Unlock()

	// Notify outside the lock: observer and access-log calls may be slow.
	for _, sess := range expired {
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
}
|
||||
|
||||
// removeSession removes a session from the map if it still matches the
// given pointer. This is safe to call concurrently with cleanupIdleSessions
// because the identity check prevents double-close when both paths race.
func (r *Relay) removeSession(sess *session) {
	r.mu.Lock()
	key := clientAddr(sess.addr.String())
	// Only tear down if the map still holds exactly this session; another
	// path (cleanup or Close) may already have removed and closed it.
	removed := r.sessions[key] == sess
	if removed {
		delete(r.sessions, key)
		sess.cancel()
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
	}
	r.mu.Unlock()

	// Report end-of-session exactly once, outside the lock.
	if removed {
		r.logger.Debugf("UDP session %s ended (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}
}
|
||||
|
||||
// logSessionEnd sends an access log entry for a completed UDP session.
|
||||
func (r *Relay) logSessionEnd(sess *session) {
|
||||
if r.accessLog == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var sourceIP netip.Addr
|
||||
if ap, err := netip.ParseAddrPort(sess.addr.String()); err == nil {
|
||||
sourceIP = ap.Addr().Unmap()
|
||||
}
|
||||
|
||||
r.accessLog.LogL4(accesslog.L4Entry{
|
||||
AccountID: r.accountID,
|
||||
ServiceID: r.serviceID,
|
||||
Protocol: accesslog.ProtocolUDP,
|
||||
Host: r.domain,
|
||||
SourceIP: sourceIP,
|
||||
DurationMs: time.Unix(0, sess.lastSeen.Load()).Sub(sess.createdAt).Milliseconds(),
|
||||
BytesUpload: sess.bytesIn.Load(),
|
||||
BytesDownload: sess.bytesOut.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
// Close stops the relay, waits for all session goroutines to exit,
// and cleans up remaining sessions.
//
// Ordering matters: the context is cancelled first so session goroutines
// stop relaying, the public listener is closed to unblock Serve, then every
// remaining session is torn down under the lock. Observer and access-log
// notifications happen after the lock is released, and sessWg.Wait() runs
// last so Close does not return while relay goroutines are still live.
func (r *Relay) Close() {
	r.cancel()
	if err := r.listener.Close(); err != nil {
		r.logger.Debugf("close UDP listener: %v", err)
	}

	var closedSessions []*session
	r.mu.Lock()
	for key, sess := range r.sessions {
		if sess == nil {
			// Defensive: drop nil entries without dereferencing them.
			delete(r.sessions, key)
			continue
		}
		r.logger.Debugf("UDP session %s closed (client→backend: %d bytes, backend→client: %d bytes)",
			sess.addr, sess.bytesIn.Load(), sess.bytesOut.Load())
		sess.cancel()
		// Closing the backend unblocks the session goroutine's Read.
		if err := sess.backend.Close(); err != nil {
			r.logger.Debugf("close session %s backend: %v", sess.addr, err)
		}
		delete(r.sessions, key)
		closedSessions = append(closedSessions, sess)
	}
	r.mu.Unlock()

	// Notify outside the lock; the goroutines' own removeSession calls
	// find the map entries already gone and will not double-report.
	for _, sess := range closedSessions {
		if r.observer != nil {
			r.observer.UDPSessionEnded(r.accountID)
		}
		r.logSessionEnd(sess)
	}

	r.sessWg.Wait()
}
|
||||
493
proxy/internal/udp/relay_test.go
Normal file
493
proxy/internal/udp/relay_test.go
Normal file
@@ -0,0 +1,493 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
func TestRelay_BasicPacketExchange(t *testing.T) {
|
||||
// Set up a UDP backend that echoes packets.
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set up the relay's public-facing listener.
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
backendAddr := backend.LocalAddr().String()
|
||||
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backendAddr, DialFunc: dialFunc})
|
||||
go relay.Serve()
|
||||
defer relay.Close()
|
||||
|
||||
// Create a client and send a packet to the relay.
|
||||
client, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
testData := []byte("hello UDP relay")
|
||||
_, err = client.Write(testData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read the echoed response.
|
||||
if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := client.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, buf[:n], "should receive echoed packet")
|
||||
}
|
||||
|
||||
func TestRelay_MultipleClients(t *testing.T) {
|
||||
backend, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, addr, err := backend.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = backend.WriteTo(buf[:n], addr)
|
||||
}
|
||||
}()
|
||||
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
|
||||
go relay.Serve()
|
||||
defer relay.Close()
|
||||
|
||||
// Two clients, each should get their own session.
|
||||
for i, msg := range []string{"client-1", "client-2"} {
|
||||
client, err := net.Dial("udp", listener.LocalAddr().String())
|
||||
require.NoError(t, err, "client %d", i)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Write([]byte(msg))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := client.Read(buf)
|
||||
require.NoError(t, err, "client %d read", i)
|
||||
assert.Equal(t, msg, string(buf[:n]), "client %d should get own echo", i)
|
||||
}
|
||||
|
||||
// Verify two sessions were created.
|
||||
relay.mu.RLock()
|
||||
sessionCount := len(relay.sessions)
|
||||
relay.mu.RUnlock()
|
||||
assert.Equal(t, 2, sessionCount, "should have two sessions")
|
||||
}
|
||||
|
||||
func TestRelay_Close(t *testing.T) {
|
||||
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: "127.0.0.1:9999", DialFunc: dialFunc})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
relay.Serve()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
relay.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Serve did not return after Close")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelay_SessionCleanup verifies that cleanupIdleSessions removes a
// session whose lastSeen timestamp is older than the session TTL.
func TestRelay_SessionCleanup(t *testing.T) {
	// Echo backend: returns every datagram to its sender.
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()

	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()

	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logger := log.NewEntry(log.StandardLogger())
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}

	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
	go relay.Serve()
	defer relay.Close()

	// Create a session by completing one request/response round trip.
	client, err := net.Dial("udp", listener.LocalAddr().String())
	require.NoError(t, err)
	_, err = client.Write([]byte("hello"))
	require.NoError(t, err)

	if err := client.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatal(err)
	}
	buf := make([]byte, 1024)
	_, err = client.Read(buf)
	require.NoError(t, err)
	client.Close()

	// The round trip must have registered exactly one session.
	relay.mu.RLock()
	assert.Equal(t, 1, len(relay.sessions))
	relay.mu.RUnlock()

	// Backdate lastSeen past the TTL so the session counts as idle.
	relay.mu.Lock()
	for _, sess := range relay.sessions {
		sess.lastSeen.Store(time.Now().Add(-2 * DefaultSessionTTL).UnixNano())
	}
	relay.mu.Unlock()

	// Trigger cleanup manually instead of waiting for the ticker.
	relay.cleanupIdleSessions()

	relay.mu.RLock()
	assert.Equal(t, 0, len(relay.sessions), "idle sessions should be cleaned up")
	relay.mu.RUnlock()
}
|
||||
|
||||
// TestRelay_CloseAndRecreate verifies that closing a relay and creating a new
// one on the same port works cleanly (simulates port mapping modify cycle).
func TestRelay_CloseAndRecreate(t *testing.T) {
	// Echo backend shared by both relay instances.
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()

	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logger := log.NewEntry(log.StandardLogger())
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}

	// First relay.
	ln1, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)

	relay1 := New(ctx, RelayConfig{Logger: logger, Listener: ln1, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
	go relay1.Serve()

	// Round trip through the first relay.
	client1, err := net.Dial("udp", ln1.LocalAddr().String())
	require.NoError(t, err)
	_, err = client1.Write([]byte("relay1"))
	require.NoError(t, err)
	require.NoError(t, client1.SetReadDeadline(time.Now().Add(2*time.Second)))
	buf := make([]byte, 1024)
	n, err := client1.Read(buf)
	require.NoError(t, err)
	assert.Equal(t, "relay1", string(buf[:n]))
	client1.Close()

	// Close first relay; this must release the UDP port.
	relay1.Close()

	// Second relay bound to the same port the first one used.
	port := ln1.LocalAddr().(*net.UDPAddr).Port
	ln2, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", port))
	require.NoError(t, err)

	relay2 := New(ctx, RelayConfig{Logger: logger, Listener: ln2, Target: backend.LocalAddr().String(), DialFunc: dialFunc})
	go relay2.Serve()
	defer relay2.Close()

	// Round trip through the second relay on the reused port.
	client2, err := net.Dial("udp", ln2.LocalAddr().String())
	require.NoError(t, err)
	defer client2.Close()
	_, err = client2.Write([]byte("relay2"))
	require.NoError(t, err)
	require.NoError(t, client2.SetReadDeadline(time.Now().Add(2*time.Second)))
	n, err = client2.Read(buf)
	require.NoError(t, err)
	assert.Equal(t, "relay2", string(buf[:n]), "second relay should work on same port")
}
|
||||
|
||||
// TestRelay_SessionLimit verifies that MaxSessions caps concurrent sessions
// and that packets from over-limit clients are dropped rather than relayed.
func TestRelay_SessionLimit(t *testing.T) {
	// Echo backend: returns every datagram to its sender.
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()

	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()

	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logger := log.NewEntry(log.StandardLogger())
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}

	// Create a relay with a max of 2 sessions.
	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), DialFunc: dialFunc, MaxSessions: 2})
	go relay.Serve()
	defer relay.Close()

	// Create 2 clients to fill up the session limit.
	for i := range 2 {
		client, err := net.Dial("udp", listener.LocalAddr().String())
		require.NoError(t, err, "client %d", i)
		defer client.Close()

		_, err = client.Write([]byte("hello"))
		require.NoError(t, err)

		require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
		buf := make([]byte, 1024)
		_, err = client.Read(buf)
		require.NoError(t, err, "client %d should get response", i)
	}

	relay.mu.RLock()
	assert.Equal(t, 2, len(relay.sessions), "should have exactly 2 sessions")
	relay.mu.RUnlock()

	// Third client should get its packet dropped (session creation fails).
	client3, err := net.Dial("udp", listener.LocalAddr().String())
	require.NoError(t, err)
	defer client3.Close()

	_, err = client3.Write([]byte("should be dropped"))
	require.NoError(t, err)

	// Short deadline: no response is expected, so the read must time out.
	require.NoError(t, client3.SetReadDeadline(time.Now().Add(500*time.Millisecond)))
	buf := make([]byte, 1024)
	_, err = client3.Read(buf)
	assert.Error(t, err, "third client should time out because session was rejected")

	relay.mu.RLock()
	assert.Equal(t, 2, len(relay.sessions), "session count should not exceed limit")
	relay.mu.RUnlock()
}
|
||||
|
||||
// testObserver records UDP session lifecycle events for test assertions.
|
||||
type testObserver struct {
|
||||
mu sync.Mutex
|
||||
started int
|
||||
ended int
|
||||
rejected int
|
||||
dialErr int
|
||||
packets int
|
||||
bytes int
|
||||
}
|
||||
|
||||
func (o *testObserver) UDPSessionStarted(types.AccountID) { o.mu.Lock(); o.started++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPSessionEnded(types.AccountID) { o.mu.Lock(); o.ended++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPSessionDialError(types.AccountID) { o.mu.Lock(); o.dialErr++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPSessionRejected(types.AccountID) { o.mu.Lock(); o.rejected++; o.mu.Unlock() }
|
||||
func (o *testObserver) UDPPacketRelayed(_ types.RelayDirection, b int) {
|
||||
o.mu.Lock()
|
||||
o.packets++
|
||||
o.bytes += b
|
||||
o.mu.Unlock()
|
||||
}
|
||||
|
||||
// TestRelay_CloseFiresObserverEnded verifies that Close reports
// UDPSessionEnded to the observer for every session still open.
func TestRelay_CloseFiresObserverEnded(t *testing.T) {
	// Echo backend: returns every datagram to its sender.
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()

	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()

	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logger := log.NewEntry(log.StandardLogger())
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}

	obs := &testObserver{}
	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc})
	relay.SetObserver(obs)
	go relay.Serve()

	// Create two sessions, each via one request/response round trip.
	for i := range 2 {
		client, err := net.Dial("udp", listener.LocalAddr().String())
		require.NoError(t, err, "client %d", i)

		_, err = client.Write([]byte("hello"))
		require.NoError(t, err)

		require.NoError(t, client.SetReadDeadline(time.Now().Add(2*time.Second)))
		buf := make([]byte, 1024)
		_, err = client.Read(buf)
		require.NoError(t, err)
		client.Close()
	}

	obs.mu.Lock()
	assert.Equal(t, 2, obs.started, "should have 2 started events")
	obs.mu.Unlock()

	// Close should fire UDPSessionEnded for all remaining sessions.
	relay.Close()

	obs.mu.Lock()
	assert.Equal(t, 2, obs.ended, "Close should fire UDPSessionEnded for each session")
	obs.mu.Unlock()
}
|
||||
|
||||
// TestRelay_SessionRateLimit verifies that the relay's session-creation
// rate limiter rejects sessions once the burst budget is exhausted, even
// when MaxSessions is far from being reached.
func TestRelay_SessionRateLimit(t *testing.T) {
	// Echo backend: returns every datagram to its sender.
	backend, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer backend.Close()

	go func() {
		buf := make([]byte, 65535)
		for {
			n, addr, err := backend.ReadFrom(buf)
			if err != nil {
				return
			}
			_, _ = backend.WriteTo(buf[:n], addr)
		}
	}()

	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logger := log.NewEntry(log.StandardLogger())
	dialFunc := func(ctx context.Context, network, address string) (net.Conn, error) {
		return net.Dial(network, address)
	}

	obs := &testObserver{}
	// High max sessions (1000) but the relay uses a rate limiter internally
	// (default: 50/s burst 100). We exhaust the burst by creating sessions
	// rapidly, then verify that subsequent creates are rejected.
	relay := New(ctx, RelayConfig{Logger: logger, Listener: listener, Target: backend.LocalAddr().String(), AccountID: "test-acct", DialFunc: dialFunc, MaxSessions: 1000})
	relay.SetObserver(obs)
	go relay.Serve()
	defer relay.Close()

	// Exhaust the burst by calling getOrCreateSession directly with
	// synthetic addresses. This is faster than real UDP round-trips.
	for i := range sessionCreateBurst + 20 {
		addr := &net.UDPAddr{IP: net.IPv4(10, 0, byte(i/256), byte(i%256)), Port: 10000 + i}
		_, _ = relay.getOrCreateSession(addr)
	}

	obs.mu.Lock()
	rejected := obs.rejected
	obs.mu.Unlock()

	assert.Greater(t, rejected, 0, "some sessions should be rate-limited")
}
|
||||
@@ -243,6 +243,10 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts is a stub on the test controller: it always
// returns nil (capability unknown for any cluster name).
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
	return nil
}
|
||||
|
||||
// storeBackedServiceManager reads directly from the real store.
|
||||
type storeBackedServiceManager struct {
|
||||
store store.Store
|
||||
@@ -505,15 +509,15 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
nil,
|
||||
"",
|
||||
0,
|
||||
mapping.GetAccountId(),
|
||||
mapping.GetId(),
|
||||
proxytypes.AccountID(mapping.GetAccountId()),
|
||||
proxytypes.ServiceID(mapping.GetId()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Apply to real proxy (idempotent)
|
||||
proxyHandler.AddMapping(proxy.Mapping{
|
||||
Host: mapping.GetDomain(),
|
||||
ID: mapping.GetId(),
|
||||
ID: proxytypes.ServiceID(mapping.GetId()),
|
||||
AccountID: proxytypes.AccountID(mapping.GetAccountId()),
|
||||
})
|
||||
}
|
||||
|
||||
782
proxy/server.go
782
proxy/server.go
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user