[management] Add stable domain resolution for combined server (#5515)

The combined server was using the hostname from exposedAddress for both
singleAccountModeDomain and dnsDomain, causing fresh installs to get
the wrong domain and existing installs to break if the config changed.
 Add resolveDomains() to BaseServer that reads domain from the store:
  - Fresh install (0 accounts): uses "netbird.selfhosted" default
  - Existing install: reads persisted domain from the account in DB
  - Store errors: falls back to default safely

The combined server opts in via AutoResolveDomains flag, while the
 standalone management server is unaffected.
This commit is contained in:
Maycon Santos
2026-03-06 08:43:46 +01:00
committed by GitHub
parent a7f3ba03eb
commit 85451ab4cd
3 changed files with 134 additions and 8 deletions

View File

@@ -493,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address // Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress) _, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil { if err != nil {
@@ -507,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
mgmtSrv := mgmtServer.NewServer( mgmtSrv := mgmtServer.NewServer(
&mgmtServer.Config{ &mgmtServer.Config{
NbConfig: mgmtConfig, NbConfig: mgmtConfig,
DNSDomain: dnsDomain, DNSDomain: "",
MgmtSingleAccModeDomain: singleAccModeDomain, MgmtSingleAccModeDomain: "",
AutoResolveDomains: true,
MgmtPort: mgmtPort, MgmtPort: mgmtPort,
MgmtMetricsPort: cfg.Server.MetricsPort, MgmtMetricsPort: cfg.Server.MetricsPort,
DisableMetrics: mgmt.DisableAnonymousMetrics, DisableMetrics: mgmt.DisableAnonymousMetrics,

View File

@@ -28,9 +28,13 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. const (
// It is used for backward compatibility now. // ManagementLegacyPort is the port that was used before by the Management gRPC server.
const ManagementLegacyPort = 33073 // It is used for backward compatibility now.
ManagementLegacyPort = 33073
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
DefaultSelfHostedDomain = "netbird.selfhosted"
)
type Server interface { type Server interface {
Start(ctx context.Context) error Start(ctx context.Context) error
@@ -58,6 +62,7 @@ type BaseServer struct {
mgmtMetricsPort int mgmtMetricsPort int
mgmtPort int mgmtPort int
disableLegacyManagementPort bool disableLegacyManagementPort bool
autoResolveDomains bool
proxyAuthClose func() proxyAuthClose func()
@@ -81,6 +86,7 @@ type Config struct {
DisableMetrics bool DisableMetrics bool
DisableGeoliteUpdate bool DisableGeoliteUpdate bool
UserDeleteFromIDPEnabled bool UserDeleteFromIDPEnabled bool
AutoResolveDomains bool
} }
// NewServer initializes and configures a new Server instance // NewServer initializes and configures a new Server instance
@@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer {
mgmtPort: cfg.MgmtPort, mgmtPort: cfg.MgmtPort,
disableLegacyManagementPort: cfg.DisableLegacyManagementPort, disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
mgmtMetricsPort: cfg.MgmtMetricsPort, mgmtMetricsPort: cfg.MgmtMetricsPort,
autoResolveDomains: cfg.AutoResolveDomains,
} }
} }
@@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.cancel = cancel s.cancel = cancel
s.errCh = make(chan error, 4) s.errCh = make(chan error, 4)
if s.autoResolveDomains {
s.resolveDomains(srvCtx)
}
s.PeersManager() s.PeersManager()
s.GeoLocationManager() s.GeoLocationManager()
@@ -381,6 +392,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
}() }()
} }
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// Fresh installs use the default self-hosted domain, while existing installs reuse the
// persisted account domain to keep addressing stable across config changes.
func (s *BaseServer) resolveDomains(ctx context.Context) {
st := s.Store()
setDefault := func(logMsg string, args ...any) {
if logMsg != "" {
log.WithContext(ctx).Warnf(logMsg, args...)
}
s.dnsDomain = DefaultSelfHostedDomain
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
}
accountsCount, err := st.GetAccountsCounter(ctx)
if err != nil {
setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
return
}
if accountsCount == 0 {
s.dnsDomain = DefaultSelfHostedDomain
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
return
}
accountID, err := st.GetAnyAccountID(ctx)
if err != nil {
setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
return
}
if accountID == "" {
setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
return
}
domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
return
}
if domain == "" {
setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
return
}
s.dnsDomain = domain
s.mgmtSingleAccModeDomain = domain
log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) { func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID() installationID := store.GetInstallationID()
if installationID != "" { if installationID != "" {

View File

@@ -0,0 +1,63 @@
package server
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
)
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil)
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
}
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}