Compare commits

...

1 Commits

Author SHA1 Message Date
Viktor Liu
e3d1b9ca88 Apply global search domains 2025-06-04 17:35:56 +02:00

View File

@@ -41,14 +41,25 @@ const (
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
interfaceConfigRegistrationEnabledKey = "RegistrationEnabled"
interfaceConfigDisableDynamicUpdateKey = "DisableDynamicUpdate"
globalDNSPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
globalSearchListKey = "SearchList"
interfaceMetricKey = "InterfaceMetric"
vpnInterfaceMetric = 10
// RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1
)
type registryConfigurator struct {
guid string
routingAll bool
gpo bool
guid string
routingAll bool
gpo bool
useGlobalSearchDomains bool
previousGlobalSearchList string
}
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -68,8 +79,9 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
}
return &registryConfigurator{
guid: guid,
gpo: useGPO,
guid: guid,
gpo: useGPO,
useGlobalSearchDomains: true,
}, nil
}
@@ -78,6 +90,14 @@ func (r *registryConfigurator) supportCustomPort() bool {
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if err := r.preventDNSRegistration(); err != nil {
log.Warnf("failed to prevent DNS registration: %v", err)
}
if err := r.setInterfaceMetric(); err != nil {
log.Warnf("failed to set interface metric: %v", err)
}
if config.RouteAll {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
return fmt.Errorf("add dns setup: %w", err)
@@ -90,7 +110,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
@@ -126,6 +149,97 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil
}
func (r *registryConfigurator) preventDNSRegistration() error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closer(regKey)
if err := regKey.SetDWordValue(interfaceConfigRegistrationEnabledKey, 0); err != nil {
log.Warnf("failed to set RegistrationEnabled: %v", err)
}
if err := regKey.SetDWordValue(interfaceConfigDisableDynamicUpdateKey, 1); err != nil {
log.Warnf("failed to set DisableDynamicUpdate: %v", err)
}
log.Infof("disabled DNS registration for NetBird interface")
return nil
}
func (r *registryConfigurator) setInterfaceMetric() error {
regKeyPath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + r.guid
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("open interface key: %w", err)
}
defer closer(regKey)
if err := regKey.SetDWordValue(interfaceMetricKey, vpnInterfaceMetric); err != nil {
return fmt.Errorf("set interface metric: %w", err)
}
return nil
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
return fmt.Errorf("set interface search domains: %w", err)
}
if r.useGlobalSearchDomains && len(domains) > 0 {
if err := r.updateGlobalSearchDomains(domains); err != nil {
return fmt.Errorf("update global search domains: %w", err)
}
}
log.Infof("updated search domains: %s (interface-specific: true, global: %v)", domains, r.useGlobalSearchDomains)
return nil
}
func (r *registryConfigurator) updateGlobalSearchDomains(domains []string) error {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, globalDNSPath, registry.QUERY_VALUE|registry.SET_VALUE)
if err != nil {
return fmt.Errorf("open global DNS parameters: %w", err)
}
defer closer(regKey)
if r.previousGlobalSearchList == "" {
currentList, _, err := regKey.GetStringValue(globalSearchListKey)
if err == nil {
r.previousGlobalSearchList = currentList
}
}
existingDomains := []string{}
if r.previousGlobalSearchList != "" {
existingDomains = strings.Split(r.previousGlobalSearchList, ",")
}
domainMap := make(map[string]bool)
for _, d := range existingDomains {
domainMap[strings.TrimSpace(d)] = true
}
for _, d := range domains {
domainMap[strings.TrimSpace(d)] = true
}
mergedDomains := make([]string, 0, len(domainMap))
for d := range domainMap {
if d != "" {
mergedDomains = append(mergedDomains, d)
}
}
if err := regKey.SetStringValue(globalSearchListKey, strings.Join(mergedDomains, ",")); err != nil {
return fmt.Errorf("set global search list: %w", err)
}
log.Infof("updated global DNS search list with NetBird domains")
return nil
}
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err)
@@ -211,14 +325,6 @@ func (r *registryConfigurator) flushDNSCache() error {
return nil
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
return fmt.Errorf("update search domains: %w", err)
}
log.Infof("updated search domains: %s", domains)
return nil
}
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
@@ -255,6 +361,12 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
}
func (r *registryConfigurator) restoreHostDNS() error {
if r.previousGlobalSearchList != "" && r.useGlobalSearchDomains {
if err := r.restoreGlobalSearchDomains(); err != nil {
log.Errorf("failed to restore global search domains: %v", err)
}
}
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("remove dns match policies: %s", err)
}
@@ -263,6 +375,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err)
}
if err := r.restoreDNSRegistration(); err != nil {
log.Warnf("failed to restore DNS registration: %v", err)
}
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
@@ -270,6 +386,34 @@ func (r *registryConfigurator) restoreHostDNS() error {
return nil
}
func (r *registryConfigurator) restoreGlobalSearchDomains() error {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, globalDNSPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("open global DNS parameters: %w", err)
}
defer closer(regKey)
if err := regKey.SetStringValue(globalSearchListKey, r.previousGlobalSearchList); err != nil {
return fmt.Errorf("restore global search list: %w", err)
}
log.Infof("restored global DNS search list")
return nil
}
func (r *registryConfigurator) restoreDNSRegistration() error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closer(regKey)
regKey.DeleteValue(interfaceConfigRegistrationEnabledKey)
regKey.DeleteValue(interfaceConfigDisableDynamicUpdateKey)
return nil
}
func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {