Feature/embedded STUN (#5062)

This commit is contained in:
Misha Bragin
2026-01-14 13:13:30 +01:00
committed by GitHub
parent 00b747ad5d
commit ff10498a8b
5 changed files with 806 additions and 45 deletions

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
)
@@ -43,6 +45,10 @@ type Config struct {
LogLevel string
LogFile string
HealthcheckListenAddress string
// STUN server configuration
EnableSTUN bool
STUNPorts []int
STUNLogLevel string
}
func (c Config) Validate() error {
@@ -52,6 +58,25 @@ func (c Config) Validate() error {
if c.AuthSecret == "" {
return fmt.Errorf("auth secret is required")
}
// Validate STUN configuration
if c.EnableSTUN {
if len(c.STUNPorts) == 0 {
return fmt.Errorf("--stun-ports is required when --enable-stun is set")
}
seen := make(map[int]bool)
for _, port := range c.STUNPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d", port)
}
seen[port] = true
}
}
return nil
}
@@ -91,6 +116,9 @@ func init() {
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server")
rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)")
rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)")
setFlagsFromEnvVars(rootCmd)
}
@@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err)
}
// Resource creation phase (fail fast before starting any goroutines)
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start metrics server: %v", err)
}
}()
srvListenerCfg := server.ListenerConfig{
Address: cobraConfig.ListenAddress,
}
@@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error {
}
srvListenerCfg.TLSConfig = tlsConfig
// Create STUN listeners early to fail fast
stunListeners, err := createSTUNListeners()
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
@@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error {
TLSSupport: tlsSupport,
}
srv, err := server.NewServer(cfg)
srv, err := createRelayServer(cfg)
if err != nil {
log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err)
cleanupSTUNListeners(stunListeners)
return err
}
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := createHealthCheck(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return err
}
var stunServer *stun.Server
if len(stunListeners) > 0 {
stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel)
}
// Start all servers (only after all resources are successfully created)
startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer)
wg.Wait()
return err
}
func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) {
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
instanceURL := srv.InstanceURL()
log.Infof("server will be available on: %s", instanceURL.String())
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Listen(srvListenerCfg); err != nil {
log.Fatalf("failed to bind server: %s", err)
log.Fatalf("failed to bind relay server: %s", err)
}
}()
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return fmt.Errorf("failed to create healthcheck server: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start healthcheck server: %v", err)
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
// it will block until exit signal
waitForExitSignal()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error {
var errs error
var shutDownErrors error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}
if err := srv.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}
wg.Wait()
return shutDownErrors
return errs
}
func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return nil, fmt.Errorf("failed to create healthcheck server: %v", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg server.Config) (*server.Server, error) {
srv, err := server.NewServer(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create relay server: %v", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners() ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cobraConfig.EnableSTUN {
for _, port := range cobraConfig.STUNPorts {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
// Close already opened listeners on failure
cleanupSTUNListeners(stunListeners)
log.Debugf("failed to create STUN listener on port %d: %v", port, err)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err)
}
stunListeners = append(stunListeners, listener)
}
}
return stunListeners, nil
}
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {