diff --git a/proxy/server.go b/proxy/server.go index c114919d5..43d725e4f 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -137,27 +137,12 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service } func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { - s.startTime = time.Now() + s.initializeDefaults() - // If no ID is set then one can be generated. - if s.ID == "" { - s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405") - } - // Fallback version option in case it is not set. - if s.Version == "" { - s.Version = "dev" - } - - // If no logger is specified fallback to the standard logger. - if s.Logger == nil { - s.Logger = log.StandardLogger() - } - - // Start up metrics gathering reg := prometheus.NewRegistry() s.meter = metrics.New(reg) - mgmtConn, err := s.dialManagement() + mgmtConn, err := s.setupManagementConnection(ctx) if err != nil { return err } @@ -166,49 +151,94 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.Logger.Debugf("management connection close: %v", err) } }() - s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) - go s.newManagementMappingWorker(ctx, s.mgmtClient) - - // Initialize the netbird client, this is required to build peer connections - // to proxy over. - s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient) tlsConfig, err := s.configureTLS(ctx) if err != nil { return err } - // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. - s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger) + accessLog := s.initializeComponents() - // Configure the authentication middleware with session validator for OIDC group checks. - s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) - - // Configure Access logs to management server. - accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) - - s.healthChecker = health.NewChecker(s.Logger, s.netbird) - - if s.DebugEndpointEnabled { - debugAddr := debugEndpointAddr(s.DebugEndpointAddress) - debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger) - if s.acme != nil { - debugHandler.SetCertStatus(s.acme) - } - s.debug = &http.Server{ - Addr: debugAddr, - Handler: debugHandler, - ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug), - } - go func() { - s.Logger.Infof("starting debug endpoint on %s", debugAddr) - if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - s.Logger.Errorf("debug endpoint error: %v", err) - } - }() + if err := s.startDebugEndpoint(); err != nil { + return err } - // Start health probe server. + if err := s.startHealthProbe(reg); err != nil { + return err + } + + listener, err := s.createHTTPSListener(addr) + if err != nil { + return err + } + + return s.serveHTTPS(ctx, listener, addr, tlsConfig, accessLog) +} + +// initializeDefaults sets default values for server configuration fields. +func (s *Server) initializeDefaults() { + s.startTime = time.Now() + + if s.ID == "" { + s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405") + } + if s.Version == "" { + s.Version = "dev" + } + if s.Logger == nil { + s.Logger = log.StandardLogger() + } +} + +// setupManagementConnection establishes the gRPC connection to the management server +// and starts the mapping worker. +func (s *Server) setupManagementConnection(ctx context.Context) (*grpc.ClientConn, error) { + mgmtConn, err := s.dialManagement() + if err != nil { + return nil, err + } + s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) + go s.newManagementMappingWorker(ctx, s.mgmtClient) + return mgmtConn, nil +} + +// initializeComponents sets up the core proxy components and returns the access logger. +func (s *Server) initializeComponents() *accesslog.Logger { + s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient) + s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger) + s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient) + accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) + s.healthChecker = health.NewChecker(s.Logger, s.netbird) + return accessLog +} + +// startDebugEndpoint starts the debug HTTP server if enabled. +func (s *Server) startDebugEndpoint() error { + if !s.DebugEndpointEnabled { + return nil + } + + debugAddr := debugEndpointAddr(s.DebugEndpointAddress) + debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger) + if s.acme != nil { + debugHandler.SetCertStatus(s.acme) + } + s.debug = &http.Server{ + Addr: debugAddr, + Handler: debugHandler, + ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug), + } + go func() { + s.Logger.Infof("starting debug endpoint on %s", debugAddr) + if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.Logger.Errorf("debug endpoint error: %v", err) + } + }() + return nil +} + +// startHealthProbe starts the health probe HTTP server. +func (s *Server) startHealthProbe(reg *prometheus.Registry) error { healthAddr := s.HealthAddress if healthAddr == "" { healthAddr = "localhost:8080" @@ -223,21 +253,23 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { s.Logger.Errorf("health probe server: %v", err) } }() + return nil +} - // Create listener with connection sniffing for HTTP redirect - // listener is closed by http.Server.ServeTLS when it exits +// createHTTPSListener creates and wraps a TCP listener for HTTPS with HTTP redirect support. +func (s *Server) createHTTPSListener(addr string) (net.Listener, error) { listener, err := net.Listen("tcp", addr) if err != nil { - return fmt.Errorf("failed to listen on %s: %w", addr, err) + return nil, fmt.Errorf("failed to listen on %s: %w", addr, err) } - - // Wrap listener to detect and redirect plain HTTP requests to HTTPS - redirectListener := &httpRedirectListener{ + return &httpRedirectListener{ Listener: listener, logger: s.Logger, - } + }, nil +} - // Start the reverse proxy HTTPS server +// serveHTTPS starts the HTTPS server and waits for it to complete or context cancellation. +func (s *Server) serveHTTPS(ctx context.Context, listener net.Listener, addr string, tlsConfig *tls.Config, accessLog *accesslog.Logger) error { s.https = &http.Server{ Addr: addr, Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))), @@ -248,7 +280,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { httpsErr := make(chan error, 1) go func() { s.Logger.Debugf("starting reverse proxy server on %s", addr) - httpsErr <- s.https.ServeTLS(redirectListener, "", "") + httpsErr <- s.https.ServeTLS(listener, "", "") }() select {