From 85451ab4cd48101ba2d68832db4902e3cca9bf1b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 6 Mar 2026 08:43:46 +0100 Subject: [PATCH] [management] Add stable domain resolution for combined server (#5515) The combined server was using the hostname from exposedAddress for both singleAccountModeDomain and dnsDomain, causing fresh installs to get the wrong domain and existing installs to break if the config changed. Add resolveDomains() to BaseServer that reads domain from the store: - Fresh install (0 accounts): uses "netbird.selfhosted" default - Existing install: reads persisted domain from the account in DB - Store errors: falls back to default safely The combined server opts in via AutoResolveDomains flag, while the standalone management server is unaffected. --- combined/cmd/root.go | 8 +-- management/internals/server/server.go | 71 ++++++++++++++++++- .../server/server_resolve_domains_test.go | 63 ++++++++++++++++ 3 files changed, 134 insertions(+), 8 deletions(-) create mode 100644 management/internals/server/server_resolve_domains_test.go diff --git a/combined/cmd/root.go b/combined/cmd/root.go index 153260341..ea1ff908a 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -493,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) { func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { mgmt := cfg.Management - dnsDomain := mgmt.DnsDomain - singleAccModeDomain := dnsDomain - // Extract port from listen address _, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress) if err != nil { @@ -507,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* mgmtSrv := mgmtServer.NewServer( &mgmtServer.Config{ NbConfig: mgmtConfig, - DNSDomain: dnsDomain, - MgmtSingleAccModeDomain: singleAccModeDomain, + DNSDomain: "", + MgmtSingleAccModeDomain: "", + AutoResolveDomains: true, MgmtPort: mgmtPort, MgmtMetricsPort: cfg.Server.MetricsPort, DisableMetrics: mgmt.DisableAnonymousMetrics, diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 5149c338b..573983a79 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -28,9 +28,13 @@ import ( "github.com/netbirdio/netbird/version" ) -// ManagementLegacyPort is the port that was used before by the Management gRPC server. -// It is used for backward compatibility now. -const ManagementLegacyPort = 33073 +const ( + // ManagementLegacyPort is the port that was used before by the Management gRPC server. + // It is used for backward compatibility now. + ManagementLegacyPort = 33073 + // DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs. + DefaultSelfHostedDomain = "netbird.selfhosted" +) type Server interface { Start(ctx context.Context) error @@ -58,6 +62,7 @@ type BaseServer struct { mgmtMetricsPort int mgmtPort int disableLegacyManagementPort bool + autoResolveDomains bool proxyAuthClose func() @@ -81,6 +86,7 @@ type Config struct { DisableMetrics bool DisableGeoliteUpdate bool UserDeleteFromIDPEnabled bool + AutoResolveDomains bool } // NewServer initializes and configures a new Server instance @@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer { mgmtPort: cfg.MgmtPort, disableLegacyManagementPort: cfg.DisableLegacyManagementPort, mgmtMetricsPort: cfg.MgmtMetricsPort, + autoResolveDomains: cfg.AutoResolveDomains, } } @@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error { s.cancel = cancel s.errCh = make(chan error, 4) + if s.autoResolveDomains { + s.resolveDomains(srvCtx) + } + s.PeersManager() s.GeoLocationManager() @@ -381,6 +392,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene }() } +// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state. +// Fresh installs use the default self-hosted domain, while existing installs reuse the +// persisted account domain to keep addressing stable across config changes. +func (s *BaseServer) resolveDomains(ctx context.Context) { + st := s.Store() + + setDefault := func(logMsg string, args ...any) { + if logMsg != "" { + log.WithContext(ctx).Warnf(logMsg, args...) + } + s.dnsDomain = DefaultSelfHostedDomain + s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain + } + + accountsCount, err := st.GetAccountsCounter(ctx) + if err != nil { + setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain) + return + } + + if accountsCount == 0 { + s.dnsDomain = DefaultSelfHostedDomain + s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain + log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain) + return + } + + accountID, err := st.GetAnyAccountID(ctx) + if err != nil { + setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain) + return + } + + if accountID == "" { + setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain) + return + } + + domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) + if err != nil { + setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain) + return + } + + if domain == "" { + setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain) + return + } + + s.dnsDomain = domain + s.mgmtSingleAccModeDomain = domain + log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain) +} + func getInstallationID(ctx context.Context, store store.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { diff --git a/management/internals/server/server_resolve_domains_test.go b/management/internals/server/server_resolve_domains_test.go new file mode 100644 index 000000000..db1d7e8ca --- /dev/null +++ b/management/internals/server/server_resolve_domains_test.go @@ -0,0 +1,63 @@ +package server + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/store" +) + +func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain) + require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain) +} + +func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil) + mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil) + mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, "vpn.mycompany.com", srv.dnsDomain) + require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain) +} + +func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed")) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain) + require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain) +}