diff --git a/combined/cmd/root.go b/combined/cmd/root.go index 153260341..ea1ff908a 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -493,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) { func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { mgmt := cfg.Management - dnsDomain := mgmt.DnsDomain - singleAccModeDomain := dnsDomain - // Extract port from listen address _, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress) if err != nil { @@ -507,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* mgmtSrv := mgmtServer.NewServer( &mgmtServer.Config{ NbConfig: mgmtConfig, - DNSDomain: dnsDomain, - MgmtSingleAccModeDomain: singleAccModeDomain, + DNSDomain: "", + MgmtSingleAccModeDomain: "", + AutoResolveDomains: true, MgmtPort: mgmtPort, MgmtMetricsPort: cfg.Server.MetricsPort, DisableMetrics: mgmt.DisableAnonymousMetrics, diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 5149c338b..573983a79 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -28,9 +28,13 @@ import ( "github.com/netbirdio/netbird/version" ) -// ManagementLegacyPort is the port that was used before by the Management gRPC server. -// It is used for backward compatibility now. -const ManagementLegacyPort = 33073 +const ( + // ManagementLegacyPort is the port that was used before by the Management gRPC server. + // It is used for backward compatibility now. + ManagementLegacyPort = 33073 + // DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs. + DefaultSelfHostedDomain = "netbird.selfhosted" +) type Server interface { Start(ctx context.Context) error @@ -58,6 +62,7 @@ type BaseServer struct { mgmtMetricsPort int mgmtPort int disableLegacyManagementPort bool + autoResolveDomains bool proxyAuthClose func() @@ -81,6 +86,7 @@ type Config struct { DisableMetrics bool DisableGeoliteUpdate bool UserDeleteFromIDPEnabled bool + AutoResolveDomains bool } // NewServer initializes and configures a new Server instance @@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer { mgmtPort: cfg.MgmtPort, disableLegacyManagementPort: cfg.DisableLegacyManagementPort, mgmtMetricsPort: cfg.MgmtMetricsPort, + autoResolveDomains: cfg.AutoResolveDomains, } } @@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error { s.cancel = cancel s.errCh = make(chan error, 4) + if s.autoResolveDomains { + s.resolveDomains(srvCtx) + } + s.PeersManager() s.GeoLocationManager() @@ -381,6 +392,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene }() } +// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state. +// Fresh installs use the default self-hosted domain, while existing installs reuse the +// persisted account domain to keep addressing stable across config changes. +func (s *BaseServer) resolveDomains(ctx context.Context) { + st := s.Store() + + setDefault := func(logMsg string, args ...any) { + if logMsg != "" { + log.WithContext(ctx).Warnf(logMsg, args...) + } + s.dnsDomain = DefaultSelfHostedDomain + s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain + } + + accountsCount, err := st.GetAccountsCounter(ctx) + if err != nil { + setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain) + return + } + + if accountsCount == 0 { + s.dnsDomain = DefaultSelfHostedDomain + s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain + log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain) + return + } + + accountID, err := st.GetAnyAccountID(ctx) + if err != nil { + setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain) + return + } + + if accountID == "" { + setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain) + return + } + + domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) + if err != nil { + setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain) + return + } + + if domain == "" { + setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain) + return + } + + s.dnsDomain = domain + s.mgmtSingleAccModeDomain = domain + log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain) +} + func getInstallationID(ctx context.Context, store store.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { diff --git a/management/internals/server/server_resolve_domains_test.go b/management/internals/server/server_resolve_domains_test.go new file mode 100644 index 000000000..db1d7e8ca --- /dev/null +++ b/management/internals/server/server_resolve_domains_test.go @@ -0,0 +1,63 @@ +package server + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/store" +) + +func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain) + require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain) +} + +func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil) + mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil) + mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, "vpn.mycompany.com", srv.dnsDomain) + require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain) +} + +func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed")) + + srv := NewServer(&Config{NbConfig: &nbconfig.Config{}}) + Inject[store.Store](srv, mockStore) + + srv.resolveDomains(context.Background()) + + require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain) + require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain) +}