[management] Add custom dns zones (#4849)

This commit is contained in:
Bethuel Mmbaga
2026-01-16 10:12:05 +01:00
committed by GitHub
parent 291e640b28
commit 067c77e49e
36 changed files with 4837 additions and 63 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return err
}
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -3,6 +3,7 @@ package controller
import (
"context"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -14,6 +15,7 @@ type Repository interface {
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
}
type repository struct {
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}

View File

@@ -0,0 +1,13 @@
package zones
import (
"context"
)
type Manager interface {
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
}

View File

@@ -0,0 +1,161 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager zones.Manager
}
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiZones := make([]*api.Zone, 0, len(allZones))
for _, zone := range allZones {
apiZones = append(apiZones, zone.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiZones)
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
zone.ID = zoneID
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,229 @@
package manager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
dnsDomain string
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
dnsDomain: dnsDomain,
}
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err
}
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing zone: %w", err)
}
}
if existingZone != nil {
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
}
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.CreateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to create zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
return zone, nil
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err)
}
if zone.Domain != updatedZone.Domain {
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
}
zone.Name = updatedZone.Name
zone.Enabled = updatedZone.Enabled
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
zone.DistributionGroups = updatedZone.DistributionGroups
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.UpdateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to update zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return zone, nil
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get records: %w", err)
}
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone dns records: %w", err)
}
err = transaction.DeleteZone(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
for _, record := range records {
eventsToStore = append(eventsToStore, func() {
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
})
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
})
return nil
})
if err != nil {
return err
}
for _, event := range eventsToStore {
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
if m.dnsDomain != "" && m.dnsDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if settings.DNSDomain != "" && settings.DNSDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
return nil
}

View File

@@ -0,0 +1,553 @@
package manager
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testZoneID = "test-zone-id"
testGroupID = "test-group-id"
testDNSDomain = "netbird.selfhosted"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone1)
require.NoError(t, err)
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, zone1.ID, result[0].ID)
assert.Equal(t, zone2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID)
assert.Equal(t, zone.Name, result.Name)
assert.Equal(t, zone.Domain, result.Domain)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneCreated, activityID)
}
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, inputZone.Name, result.Name)
assert.Equal(t, inputZone.Domain, result.Domain)
assert.Equal(t, inputZone.Enabled, result.Enabled)
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{"invalid-group"},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "duplicate.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.AlreadyExists, s.Type())
})
t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
account, err := testStore.GetAccount(ctx, testAccountID)
require.NoError(t, err)
account.Settings.DNSDomain = "peers.example.com"
err = testStore.SaveAccount(ctx, account)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: "peers.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: testDNSDomain,
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
}
func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Updated Name",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingZone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneUpdated, activityID)
}
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedZone.Name, result.Name)
assert.Equal(t, updatedZone.Enabled, result.Enabled)
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Test Zone",
Domain: "different.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain cannot be updated")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: "non-existent-zone",
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background()
t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 3, storeEventCallCount)
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.NoError(t, err)
assert.Empty(t, zoneRecords)
})
t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, zone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneDeleted, activityID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,13 @@
package records
import (
"context"
)
type Manager interface {
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
}

View File

@@ -0,0 +1,191 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager records.Manager
}
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
for _, record := range allRecords {
apiRecords = append(apiRecords, record.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiRecords)
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
record.ID = recordID
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,236 @@
package manager
import (
"context"
"fmt"
"strings"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
err = validateRecordConflicts(ctx, transaction, zone, record)
if err != nil {
return err
}
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to create dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
var record *records.Record
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
record.Name = updatedRecord.Name
record.Type = updatedRecord.Type
record.Content = updatedRecord.Content
record.TTL = updatedRecord.TTL
if hasChanges {
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
return err
}
}
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to update dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record
var zone *zones.Zone
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to delete dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// validateRecordConflicts checks for duplicate records and CNAME conflicts
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
}
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
if err != nil {
return fmt.Errorf("failed to check existing records: %w", err)
}
for _, existing := range existingRecords {
if existing.ID == record.ID {
continue
}
if existing.Type == record.Type && existing.Content == record.Content {
return status.Errorf(status.AlreadyExists, "identical record already exists")
}
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
return status.Errorf(status.InvalidArgument,
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
}
}
return nil
}

View File

@@ -0,0 +1,573 @@
package manager
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testRecordID = "test-record-id"
testGroupID = "test-group-id"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err = testStore.CreateZone(ctx, zone)
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, record1.ID, result[0].ID)
assert.Equal(t, record2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.Equal(t, record.ID, result.ID)
assert.Equal(t, record.Name, result.Name)
assert.Equal(t, record.Type, result.Type)
assert.Equal(t, record.Content, result.Content)
assert.Equal(t, record.TTL, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, zone.ID, result.ZoneID)
assert.Equal(t, inputRecord.Name, result.Name)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
assert.Equal(t, inputRecord.TTL, result.TTL)
})
t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "ipv6.example.com",
Type: records.RecordTypeAAAA,
Content: "2001:db8::1",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "www.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.different.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "does not belong to zone")
})
t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "already exists")
})
}
func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100", // Changed IP
TTL: 600, // Changed TTL
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingRecord.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordUpdated, activityID)
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedRecord.Content, result.Content)
assert.Equal(t, updatedRecord.TTL, result.TTL)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: existingRecord.Name,
Type: existingRecord.Type,
Content: existingRecord.Content,
TTL: 600, // Only TTL changed
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, 600, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: "non-existent-record",
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: record2.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
}
func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, record.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordDeleted, activityID)
}
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,129 @@
package records
import (
"errors"
"net"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type RecordType string
const (
RecordTypeA RecordType = "A"
RecordTypeAAAA RecordType = "AAAA"
RecordTypeCNAME RecordType = "CNAME"
)
type Record struct {
AccountID string `gorm:"index"`
ZoneID string `gorm:"index"`
ID string `gorm:"primaryKey"`
Name string
Type RecordType
Content string
TTL int
}
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
return &Record{
ID: xid.New().String(),
AccountID: accountID,
ZoneID: zoneID,
Name: name,
Type: recordType,
Content: content,
TTL: ttl,
}
}
func (r *Record) ToAPIResponse() *api.DNSRecord {
recordType := api.DNSRecordType(r.Type)
return &api.DNSRecord{
Id: r.ID,
Name: r.Name,
Type: recordType,
Content: r.Content,
Ttl: r.TTL,
}
}
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
r.Name = req.Name
r.Type = RecordType(req.Type)
r.Content = req.Content
r.TTL = req.Ttl
}
func (r *Record) Validate() error {
if r.Name == "" {
return errors.New("record name is required")
}
if !util.IsValidDomain(r.Name) {
return errors.New("invalid record name format")
}
if r.Type == "" {
return errors.New("record type is required")
}
switch r.Type {
case RecordTypeA:
if err := validateIPv4(r.Content); err != nil {
return err
}
case RecordTypeAAAA:
if err := validateIPv6(r.Content); err != nil {
return err
}
case RecordTypeCNAME:
if !util.IsValidDomain(r.Content) {
return errors.New("invalid CNAME record format")
}
default:
return errors.New("invalid record type, must be A, AAAA, or CNAME")
}
if r.TTL < 0 {
return errors.New("TTL cannot be negative")
}
return nil
}
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
return map[string]any{
"name": r.Name,
"type": string(r.Type),
"content": r.Content,
"ttl": r.TTL,
"zone_id": zoneID,
"zone_name": zoneName,
}
}
func validateIPv4(content string) error {
if content == "" {
return errors.New("A record is required") //nolint:staticcheck
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() == nil {
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
}
return nil
}
func validateIPv6(content string) error {
if content == "" {
return errors.New("AAAA record is required")
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() != nil {
return errors.New("AAAA record must be a valid IPv6 address")
}
return nil
}

View File

@@ -0,0 +1,89 @@
package zones
import (
"errors"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type Zone struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string
Enabled bool
EnableSearchDomain bool
DistributionGroups []string `gorm:"serializer:json"`
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
}
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
return &Zone{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
Enabled: enabled,
EnableSearchDomain: enableSearchDomain,
DistributionGroups: distributionGroups,
}
}
func (z *Zone) ToAPIResponse() *api.Zone {
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
for _, record := range z.Records {
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
apiRecords = append(apiRecords, *apiRecord)
}
}
return &api.Zone{
DistributionGroups: z.DistributionGroups,
Domain: z.Domain,
EnableSearchDomain: z.EnableSearchDomain,
Enabled: z.Enabled,
Id: z.ID,
Name: z.Name,
Records: apiRecords,
}
}
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
z.Name = req.Name
z.Domain = req.Domain
z.EnableSearchDomain = req.EnableSearchDomain
z.DistributionGroups = req.DistributionGroups
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
}
z.Enabled = enabled
}
func (z *Zone) Validate() error {
if z.Name == "" {
return errors.New("zone name is required")
}
if len(z.Name) > 255 {
return errors.New("zone name exceeds maximum length of 255 characters")
}
if !util.IsValidDomain(z.Domain) {
return errors.New("invalid zone domain format")
}
if len(z.DistributionGroups) == 0 {
return errors.New("at least one distribution group is required")
}
return nil
}
func (z *Zone) EventMeta() map[string]any {
return map[string]any{"name": z.Name, "domain": z.Domain}
}

View File

@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}

View File

@@ -8,6 +8,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}
func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
})
}
func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
})
}

View File

@@ -374,9 +374,10 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
NonAuthoritative: zone.NonAuthoritative,
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{

View File

@@ -295,7 +295,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
@@ -388,7 +388,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -402,6 +402,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" {
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing zone: %w", err)
}
}
if existingZone != nil {
return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain)
}
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -397,7 +398,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -1676,7 +1677,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
},
}
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
@@ -1686,7 +1687,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
@@ -2095,6 +2096,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T)
}
}
func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
ctx := context.Background()
err = manager.Store.CreateZone(ctx, &zones.Zone{
ID: "test-zone-id",
AccountID: accountID,
Name: "Test Zone",
Domain: "custom.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{},
})
require.NoError(t, err, "unable to create custom DNS zone")
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
DNSDomain: "custom.example.com",
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone")
assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone")
}
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string

View File

@@ -187,6 +187,14 @@ const (
IdentityProviderUpdated Activity = 94
IdentityProviderDeleted Activity = 95
DNSZoneCreated Activity = 96
DNSZoneUpdated Activity = 97
DNSZoneDeleted Activity = 98
DNSRecordCreated Activity = 99
DNSRecordUpdated Activity = 100
DNSRecordDeleted Activity = 101
AccountDeleted Activity = 99999
)
@@ -303,6 +311,14 @@ var activityMap = map[Activity]Code{
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
DNSZoneCreated: {"DNS zone created", "dns.zone.create"},
DNSZoneUpdated: {"DNS zone updated", "dns.zone.update"},
DNSZoneDeleted: {"DNS zone deleted", "dns.zone.delete"},
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -15,7 +15,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
@@ -56,7 +59,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -138,6 +141,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)

View File

@@ -10,6 +10,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -298,8 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}

View File

@@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
@@ -93,8 +95,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -27,6 +27,8 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -123,6 +125,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&zones.Zone{}, &records.Record{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -4179,3 +4182,184 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS
return userID, nil
}
func (s *SqlStore) CreateZone(ctx context.Context, zone *zones.Zone) error {
result := s.db.Create(zone)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create zone to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create zone to store")
}
return nil
}
func (s *SqlStore) UpdateZone(ctx context.Context, zone *zones.Zone) error {
result := s.db.Select("*").Save(zone)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update zone to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to update zone to store")
}
return nil
}
func (s *SqlStore) DeleteZone(ctx context.Context, accountID, zoneID string) error {
result := s.db.Delete(&zones.Zone{}, accountAndIDQueryCondition, accountID, zoneID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete zone from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete zone from store")
}
if result.RowsAffected == 0 {
return status.NewZoneNotFoundError(zoneID)
}
return nil
}
func (s *SqlStore) GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var zone *zones.Zone
result := tx.Preload("Records").Take(&zone, accountAndIDQueryCondition, accountID, zoneID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewZoneNotFoundError(zoneID)
}
log.WithContext(ctx).Errorf("failed to get zone from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone from store")
}
return zone, nil
}
func (s *SqlStore) GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) {
var zone *zones.Zone
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&zone)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewZoneNotFoundError(domain)
}
log.WithContext(ctx).Errorf("failed to get zone by domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone by domain from store")
}
return zone, nil
}
func (s *SqlStore) GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var zones []*zones.Zone
result := tx.Preload("Records").Find(&zones, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get zones from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zones from store")
}
return zones, nil
}
func (s *SqlStore) CreateDNSRecord(ctx context.Context, record *records.Record) error {
result := s.db.Create(record)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create dns record to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create dns record to store")
}
return nil
}
func (s *SqlStore) UpdateDNSRecord(ctx context.Context, record *records.Record) error {
result := s.db.Select("*").Save(record)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update dns record to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to update dns record to store")
}
return nil
}
func (s *SqlStore) DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error {
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete dns record from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete dns record from store")
}
if result.RowsAffected == 0 {
return status.NewDNSRecordNotFoundError(recordID)
}
return nil
}
func (s *SqlStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record *records.Record
result := tx.Where("account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID).Take(&record)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewDNSRecordNotFoundError(recordID)
}
log.WithContext(ctx).Errorf("failed to get dns record from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get dns record from store")
}
return record, nil
}
func (s *SqlStore) GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var recordsList []*records.Record
result := tx.Where("account_id = ? AND zone_id = ?", accountID, zoneID).Find(&recordsList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get zone dns records from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone dns records from store")
}
return recordsList, nil
}
func (s *SqlStore) GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var recordsList []*records.Record
result := tx.Where("account_id = ? AND zone_id = ? AND name = ?", accountID, zoneID, name).Find(&recordsList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get zone dns records by name from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone dns records by name from store")
}
return recordsList, nil
}
func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error {
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ?", accountID, zoneID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete zone dns records from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete zone dns records from store")
}
return nil
}

View File

@@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -4025,3 +4027,476 @@ func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
}
func TestSqlStore_CreateZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, zone.ID, savedZone.ID)
assert.Equal(t, zone.Name, savedZone.Name)
assert.Equal(t, zone.Domain, savedZone.Domain)
assert.Equal(t, zone.Enabled, savedZone.Enabled)
assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain)
assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups)
}
func TestSqlStore_GetZoneByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
tests := []struct {
name string
accountID string
zoneID string
expectError bool
}{
{
name: "retrieve existing zone",
accountID: accountID,
zoneID: zone.ID,
expectError: false,
},
{
name: "retrieve non-existing zone",
accountID: accountID,
zoneID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty zone ID",
accountID: accountID,
zoneID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, savedZone)
} else {
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, tt.zoneID, savedZone.ID)
}
})
}
}
func TestSqlStore_GetAccountZones(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone1)
require.NoError(t, err)
zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"})
err = store.CreateZone(context.Background(), zone2)
require.NoError(t, err)
allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID)
require.NoError(t, err)
require.NotNil(t, allZones)
assert.GreaterOrEqual(t, len(allZones), 2)
zoneIDs := make(map[string]bool)
for _, z := range allZones {
zoneIDs[z.ID] = true
}
assert.True(t, zoneIDs[zone1.ID])
assert.True(t, zoneIDs[zone2.ID])
}
func TestSqlStore_GetZoneByDomain(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
tests := []struct {
name string
accountID string
domain string
expectError bool
errorType status.Type
}{
{
name: "retrieve existing zone by domain",
accountID: accountID,
domain: "example.com",
expectError: false,
},
{
name: "retrieve non-existing zone domain",
accountID: accountID,
domain: "non-existing.com",
expectError: true,
errorType: status.NotFound,
},
{
name: "retrieve with empty domain",
accountID: accountID,
domain: "",
expectError: true,
errorType: status.NotFound,
},
{
name: "retrieve with different account ID",
accountID: otherAccountID,
domain: "example.com",
expectError: true,
errorType: status.NotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, tt.errorType, sErr.Type())
require.Nil(t, savedZone)
} else {
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, tt.domain, savedZone.Domain)
assert.Equal(t, zone.ID, savedZone.ID)
assert.Equal(t, zone.Name, savedZone.Name)
}
})
}
}
func TestSqlStore_UpdateZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
zone.Name = "Updated Zone"
zone.Domain = "updated.com"
zone.Enabled = false
zone.EnableSearchDomain = true
zone.DistributionGroups = []string{"group2", "group3"}
err = store.UpdateZone(context.Background(), zone)
require.NoError(t, err)
updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, updatedZone)
assert.Equal(t, "Updated Zone", updatedZone.Name)
assert.Equal(t, "updated.com", updatedZone.Domain)
assert.False(t, updatedZone.Enabled)
assert.True(t, updatedZone.EnableSearchDomain)
assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups)
}
func TestSqlStore_DeleteZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
err = store.DeleteZone(context.Background(), accountID, zone.ID)
require.NoError(t, err)
deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.Error(t, err)
require.Nil(t, deletedZone)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
}
func TestSqlStore_CreateDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.NoError(t, err)
require.NotNil(t, savedRecord)
assert.Equal(t, record.ID, savedRecord.ID)
assert.Equal(t, record.Name, savedRecord.Name)
assert.Equal(t, record.Type, savedRecord.Type)
assert.Equal(t, record.Content, savedRecord.Content)
assert.Equal(t, record.TTL, savedRecord.TTL)
assert.Equal(t, zone.ID, savedRecord.ZoneID)
}
func TestSqlStore_GetDNSRecordByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
tests := []struct {
name string
accountID string
zoneID string
recordID string
expectError bool
}{
{
name: "retrieve existing record",
accountID: accountID,
zoneID: zone.ID,
recordID: record.ID,
expectError: false,
},
{
name: "retrieve non-existing record",
accountID: accountID,
zoneID: zone.ID,
recordID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty record ID",
accountID: accountID,
zoneID: zone.ID,
recordID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, savedRecord)
} else {
require.NoError(t, err)
require.NotNil(t, savedRecord)
assert.Equal(t, tt.recordID, savedRecord.ID)
}
})
}
}
func TestSqlStore_GetZoneDNSRecords(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), recordA)
require.NoError(t, err)
recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
err = store.CreateDNSRecord(context.Background(), recordAAAA)
require.NoError(t, err)
recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300)
err = store.CreateDNSRecord(context.Background(), recordCNAME)
require.NoError(t, err)
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, allRecords)
assert.Equal(t, 3, len(allRecords))
recordIDs := make(map[string]bool)
for _, r := range allRecords {
recordIDs[r.ID] = true
}
assert.True(t, recordIDs[recordA.ID])
assert.True(t, recordIDs[recordAAAA.ID])
assert.True(t, recordIDs[recordCNAME.ID])
}
func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record1)
require.NoError(t, err)
record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
err = store.CreateDNSRecord(context.Background(), record2)
require.NoError(t, err)
record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
err = store.CreateDNSRecord(context.Background(), record3)
require.NoError(t, err)
recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com")
require.NoError(t, err)
require.NotNil(t, recordsByName)
assert.Equal(t, 2, len(recordsByName))
for _, r := range recordsByName {
assert.Equal(t, "www.example.com", r.Name)
}
}
func TestSqlStore_UpdateDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
record.Name = "api.example.com"
record.Content = "192.168.1.100"
record.TTL = 600
err = store.UpdateDNSRecord(context.Background(), record)
require.NoError(t, err)
updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.NoError(t, err)
require.NotNil(t, updatedRecord)
assert.Equal(t, "api.example.com", updatedRecord.Name)
assert.Equal(t, "192.168.1.100", updatedRecord.Content)
assert.Equal(t, 600, updatedRecord.TTL)
}
func TestSqlStore_DeleteDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID)
require.NoError(t, err)
deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.Error(t, err)
require.Nil(t, deletedRecord)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
}
func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record1)
require.NoError(t, err)
record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
err = store.CreateDNSRecord(context.Background(), record2)
require.NoError(t, err)
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 2, len(allRecords))
err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID)
require.NoError(t, err)
remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 0, len(remainingRecords))
}

View File

@@ -23,6 +23,8 @@ import (
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types"
@@ -209,6 +211,21 @@ type Store interface {
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
SetFieldEncrypt(enc *crypt.FieldEncrypt)
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
CreateZone(ctx context.Context, zone *zones.Zone) error
UpdateZone(ctx context.Context, zone *zones.Zone) error
DeleteZone(ctx context.Context, accountID, zoneID string) error
GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error)
GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error)
GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error)
CreateDNSRecord(ctx context.Context, record *records.Record) error
UpdateDNSRecord(ctx context.Context, record *records.Record) error
DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error
GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error)
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
}
const (

View File

@@ -18,6 +18,8 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -150,17 +152,16 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
groupListMap := a.GetPeerGroups(peerID)
for _, peer := range aclPeers {
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
@@ -274,6 +275,7 @@ func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
@@ -294,6 +296,8 @@ func (a *Account) GetPeerNetworkMap(
}
}
peerGroups := a.GetPeerGroups(peerID)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
@@ -307,7 +311,7 @@ func (a *Account) GetPeerNetworkMap(
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
var networkResourcesFirewallRules []*RouteFirewallRule
@@ -323,6 +327,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
@@ -330,6 +335,10 @@ func (a *Account) GetPeerNetworkMap(
Records: records,
})
}
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
zones = append(zones, filteredAccountZones...)
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
@@ -1881,3 +1890,66 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
return filteredRecords
}
// filterPeerAppliedZones filters account zones based on the peer's group membership
func filterPeerAppliedZones(ctx context.Context, accountZones []*zones.Zone, peerGroups LookupMap) []nbdns.CustomZone {
var customZones []nbdns.CustomZone
if len(peerGroups) == 0 {
return customZones
}
for _, zone := range accountZones {
if !zone.Enabled || len(zone.Records) == 0 {
continue
}
hasAccess := false
for _, distGroupID := range zone.DistributionGroups {
if _, found := peerGroups[distGroupID]; found {
hasAccess = true
break
}
}
if !hasAccess {
continue
}
simpleRecords := make([]nbdns.SimpleRecord, 0, len(zone.Records))
for _, record := range zone.Records {
var recordType int
rData := record.Content
switch record.Type {
case records.RecordTypeA:
recordType = int(dns.TypeA)
case records.RecordTypeAAAA:
recordType = int(dns.TypeAAAA)
case records.RecordTypeCNAME:
recordType = int(dns.TypeCNAME)
rData = dns.Fqdn(record.Content)
default:
log.WithContext(ctx).Warnf("unknown DNS record type %s for record %s", record.Type, record.ID)
continue
}
simpleRecords = append(simpleRecords, nbdns.SimpleRecord{
Name: dns.Fqdn(record.Name),
Type: recordType,
Class: nbdns.DefaultClass,
TTL: record.TTL,
RData: rData,
})
}
customZones = append(customZones, nbdns.CustomZone{
Domain: dns.Fqdn(zone.Domain),
Records: simpleRecords,
SearchDomainDisabled: !zone.EnableSearchDomain,
NonAuthoritative: true,
})
}
return customZones
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -1425,3 +1427,515 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
})
}
}
func Test_filterPeerAppliedZones(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
accountZones []*zones.Zone
peerGroups LookupMap
expected []nbdns.CustomZone
}{
{
name: "empty peer groups returns empty custom zones",
accountZones: []*zones.Zone{},
peerGroups: LookupMap{},
expected: []nbdns.CustomZone{},
},
{
name: "peer has access to zone with A record",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "peer has access to zone with search domain enabled",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "internal.local",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "api.internal.local",
Type: records.RecordTypeA,
Content: "10.0.0.1",
TTL: 600,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "internal.local.",
Records: []nbdns.SimpleRecord{
{
Name: "api.internal.local.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 600,
RData: "10.0.0.1",
},
},
SearchDomainDisabled: false,
},
},
},
{
name: "peer has no access to zone",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "private.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group2"},
Records: []*records.Record{
{
ID: "record1",
Name: "secret.private.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{},
},
{
name: "disabled zone is filtered out",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "disabled.com",
Enabled: false,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.disabled.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{},
},
{
name: "zone with no records is filtered out",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "empty.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{},
},
{
name: "peer has access via multiple groups",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "multi.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1", "group2", "group3"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.multi.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group2": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "multi.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.multi.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "multiple zones with mixed access",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "allowed.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.allowed.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
{
ID: "zone2",
Domain: "denied.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group2"},
Records: []*records.Record{
{
ID: "record2",
Name: "www.denied.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "allowed.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.allowed.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "zone with multiple record types",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "mixed.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.mixed.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
{
ID: "record2",
Name: "ipv6.mixed.com",
Type: records.RecordTypeAAAA,
Content: "2001:db8::1",
TTL: 600,
},
{
ID: "record3",
Name: "alias.mixed.com",
Type: records.RecordTypeCNAME,
Content: "www.mixed.com",
TTL: 900,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "mixed.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.mixed.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
{
Name: "ipv6.mixed.com.",
Type: int(dns.TypeAAAA),
Class: nbdns.DefaultClass,
TTL: 600,
RData: "2001:db8::1",
},
{
Name: "alias.mixed.com.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 900,
RData: "www.mixed.com.",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "multiple zones both accessible",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "first.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.first.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
{
ID: "zone2",
Domain: "second.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record2",
Name: "www.second.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 600,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "first.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.first.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: false,
},
{
Domain: "second.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.second.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 600,
RData: "192.168.1.2",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "zone with multiple records of same type",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "multi-a.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.multi-a.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
{
ID: "record2",
Name: "www.multi-a.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "multi-a.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.multi-a.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
{
Name: "www.multi-a.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.2",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "peer in multiple groups accessing different zones",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "zone1.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.zone1.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
{
ID: "zone2",
Domain: "zone2.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group2"},
Records: []*records.Record{
{
ID: "record2",
Name: "www.zone2.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}, "group2": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "zone1.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.zone1.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
{
Domain: "zone2.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.zone2.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.2",
},
},
SearchDomainDisabled: true,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterPeerAppliedZones(ctx, tt.accountZones, tt.peerGroups)
require.Equal(t, len(tt.expected), len(result), "number of custom zones should match")
for i, expectedZone := range tt.expected {
assert.Equal(t, expectedZone.Domain, result[i].Domain, "domain should match")
assert.Equal(t, expectedZone.SearchDomainDisabled, result[i].SearchDomainDisabled, "search domain disabled flag should match")
assert.Equal(t, len(expectedZone.Records), len(result[i].Records), "number of records should match")
for j, expectedRecord := range expectedZone.Records {
assert.Equal(t, expectedRecord.Name, result[i].Records[j].Name, "record name should match")
assert.Equal(t, expectedRecord.Type, result[i].Records[j].Type, "record type should match")
assert.Equal(t, expectedRecord.Class, result[i].Records[j].Class, "record class should match")
assert.Equal(t, expectedRecord.TTL, result[i].Records[j].TTL, "record TTL should match")
assert.Equal(t, expectedRecord.RData, result[i].Records[j].RData, "record RData should match")
}
}
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -25,11 +26,12 @@ func (a *Account) GetPeerNetworkMapExp(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeers map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
a.initNetworkMapBuilder(validatedPeers)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {

View File

@@ -70,13 +70,13 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -115,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -124,7 +124,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
for range b.N {
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
for _, peerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -177,7 +177,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -185,7 +185,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
err = builder.OnPeerAddedIncremental(account, newPeerID)
require.NoError(t, err, "error adding peer to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -240,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -250,7 +250,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -317,7 +317,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -325,7 +325,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
err = builder.OnPeerAddedIncremental(account, newRouterID)
require.NoError(t, err, "error adding router to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -402,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -412,7 +412,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newRouterID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -458,7 +458,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -466,7 +466,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
err = builder.OnPeerDeleted(account, deletedPeerID)
require.NoError(t, err, "error deleting peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -537,7 +537,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -545,7 +545,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
err = builder.OnPeerDeleted(account, deletedRouterID)
require.NoError(t, err, "error deleting routing peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -597,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -607,7 +607,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerDeleted(account, deletedPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -944,7 +944,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T
time.Sleep(100 * time.Millisecond)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(networkMap)

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -1033,7 +1034,7 @@ func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account {
}
func (b *NetworkMapBuilder) GetPeerNetworkMap(
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone,
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
@@ -1057,7 +1058,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
return &NetworkMap{Network: account.Network.Copy()}
}
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers)
if metrics != nil {
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
@@ -1074,8 +1075,8 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
}
func (b *NetworkMapBuilder) assembleNetworkMap(
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{},
) *NetworkMap {
var peersToConnect []*nbpeer.Peer
@@ -1125,13 +1126,26 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
}
finalDNSConfig := *dnsConfig
if finalDNSConfig.ServiceEnable && customZone.Domain != "" {
if finalDNSConfig.ServiceEnable {
var zones []nbdns.CustomZone
records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: customZone.Domain,
Records: records,
})
peerGroupsSlice := b.cache.peerToGroups[peer.ID]
peerGroups := make(LookupMap, len(peerGroupsSlice))
for _, groupID := range peerGroupsSlice {
peerGroups[groupID] = struct{}{}
}
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
})
}
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
zones = append(zones, filteredAccountZones...)
finalDNSConfig.CustomZones = zones
}

View File

@@ -1,5 +1,9 @@
package util
import "regexp"
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
// Difference returns the elements in `a` that aren't in `b`.
func Difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
@@ -50,3 +54,10 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
}
return false
}
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(domain)
}

View File

@@ -59,9 +59,13 @@ type Client struct {
Routes *RoutesAPI
// DNS NetBird DNS APIs
// see more: https://docs.netbird.io/api/resources/routes
// see more: https://docs.netbird.io/api/resources/dns
DNS *DNSAPI
// DNSZones NetBird DNS Zones APIs
// see more: https://docs.netbird.io/api/resources/dns-zones
DNSZones *DNSZonesAPI
// GeoLocation NetBird Geo Location APIs
// see more: https://docs.netbird.io/api/resources/geo-locations
GeoLocation *GeoLocationAPI
@@ -113,6 +117,7 @@ func (c *Client) initialize() {
c.Networks = &NetworksAPI{c}
c.Routes = &RoutesAPI{c}
c.DNS = &DNSAPI{c}
c.DNSZones = &DNSZonesAPI{c}
c.GeoLocation = &GeoLocationAPI{c}
c.Events = &EventsAPI{c}
}

View File

@@ -0,0 +1,170 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// DNSZonesAPI APIs for DNS Zones Management, do not use directly
type DNSZonesAPI struct {
c *Client
}
// ListZones list all DNS zones
// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-zones
func (a *DNSZonesAPI) ListZones(ctx context.Context) ([]api.Zone, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Zone](resp)
return ret, err
}
// GetZone get DNS zone info
// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-zone
func (a *DNSZonesAPI) GetZone(ctx context.Context, zoneID string) (*api.Zone, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Zone](resp)
return &ret, err
}
// CreateZone create new DNS zone
// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-zone
func (a *DNSZonesAPI) CreateZone(ctx context.Context, request api.PostApiDnsZonesJSONRequestBody) (*api.Zone, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Zone](resp)
return &ret, err
}
// UpdateZone update DNS zone info
// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-zone
func (a *DNSZonesAPI) UpdateZone(ctx context.Context, zoneID string, request api.PutApiDnsZonesZoneIdJSONRequestBody) (*api.Zone, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Zone](resp)
return &ret, err
}
// DeleteZone delete DNS zone
// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-zone
func (a *DNSZonesAPI) DeleteZone(ctx context.Context, zoneID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ListRecords list all DNS records in a zone
// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-records
func (a *DNSZonesAPI) ListRecords(ctx context.Context, zoneID string) ([]api.DNSRecord, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.DNSRecord](resp)
return ret, err
}
// GetRecord get DNS record info
// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-record
func (a *DNSZonesAPI) GetRecord(ctx context.Context, zoneID, recordID string) (*api.DNSRecord, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSRecord](resp)
return &ret, err
}
// CreateRecord create new DNS record in a zone
// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-record
func (a *DNSZonesAPI) CreateRecord(ctx context.Context, zoneID string, request api.PostApiDnsZonesZoneIdRecordsJSONRequestBody) (*api.DNSRecord, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones/"+zoneID+"/records", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSRecord](resp)
return &ret, err
}
// UpdateRecord update DNS record info
// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-record
func (a *DNSZonesAPI) UpdateRecord(ctx context.Context, zoneID, recordID string, request api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody) (*api.DNSRecord, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID+"/records/"+recordID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSRecord](resp)
return &ret, err
}
// DeleteRecord delete DNS record
// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-record
func (a *DNSZonesAPI) DeleteRecord(ctx context.Context, zoneID, recordID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,460 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testZone = api.Zone{
Id: "zone123",
Name: "test-zone",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
}
testDNSRecord = api.DNSRecord{
Id: "record123",
Name: "www",
Content: "192.168.1.1",
Type: api.DNSRecordTypeA,
Ttl: 300,
}
)
func TestDNSZone_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.Zone{testZone})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListZones(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testZone, ret[0])
})
}
func TestDNSZone_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListZones(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSZone_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testZone)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetZone(context.Background(), "zone123")
require.NoError(t, err)
assert.Equal(t, testZone, *ret)
})
}
func TestDNSZone_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetZone(context.Background(), "zone123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSZone_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiDnsZonesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "test-zone", req.Name)
assert.Equal(t, "example.com", req.Domain)
retBytes, _ := json.Marshal(testZone)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
enabled := true
ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{
Name: "test-zone",
Domain: "example.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
})
require.NoError(t, err)
assert.Equal(t, testZone, *ret)
})
}
func TestDNSZone_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{
Name: "test-zone",
Domain: "example.com",
})
assert.Error(t, err)
assert.Equal(t, "Invalid request", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSZone_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsZonesZoneIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "updated-zone", req.Name)
retBytes, _ := json.Marshal(testZone)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
enabled := true
ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "updated-zone",
Domain: "example.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
})
require.NoError(t, err)
assert.Equal(t, testZone, *ret)
})
}
func TestDNSZone_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "updated-zone",
Domain: "example.com",
})
assert.Error(t, err)
assert.Equal(t, "Invalid request", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSZone_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.DNSZones.DeleteZone(context.Background(), "zone123")
require.NoError(t, err)
})
}
func TestDNSZone_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.DNSZones.DeleteZone(context.Background(), "zone123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestDNSRecord_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.DNSRecord{testDNSRecord})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListRecords(context.Background(), "zone123")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testDNSRecord, ret[0])
})
}
func TestDNSRecord_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Zone not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListRecords(context.Background(), "zone123")
assert.Error(t, err)
assert.Equal(t, "Zone not found", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSRecord_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testDNSRecord)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123")
require.NoError(t, err)
assert.Equal(t, testDNSRecord, *ret)
})
}
func TestDNSRecord_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSRecord_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "www", req.Name)
assert.Equal(t, "192.168.1.1", req.Content)
assert.Equal(t, api.DNSRecordTypeA, req.Type)
retBytes, _ := json.Marshal(testDNSRecord)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "www",
Content: "192.168.1.1",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
require.NoError(t, err)
assert.Equal(t, testDNSRecord, *ret)
})
}
func TestDNSRecord_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "www",
Content: "192.168.1.1",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
assert.Error(t, err)
assert.Equal(t, "Invalid record", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSRecord_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "api", req.Name)
assert.Equal(t, "192.168.1.2", req.Content)
retBytes, _ := json.Marshal(testDNSRecord)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "api",
Content: "192.168.1.2",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
require.NoError(t, err)
assert.Equal(t, testDNSRecord, *ret)
})
}
func TestDNSRecord_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "api",
Content: "192.168.1.2",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
assert.Error(t, err)
assert.Equal(t, "Invalid record", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSRecord_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123")
require.NoError(t, err)
})
}
func TestDNSRecord_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestDNSZones_Integration(t *testing.T) {
enabled := true
zoneReq := api.ZoneRequest{
Name: "test-zone",
Domain: "test.example.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{"cs1tnh0hhcjnqoiuebeg"},
}
recordReq := api.DNSRecordRequest{
Name: "api.test.example.com",
Content: "192.168.1.100",
Type: api.DNSRecordTypeA,
Ttl: 300,
}
withBlackBoxServer(t, func(c *rest.Client) {
zone, err := c.DNSZones.CreateZone(context.Background(), zoneReq)
require.NoError(t, err)
assert.Equal(t, "test-zone", zone.Name)
assert.Equal(t, "test.example.com", zone.Domain)
zones, err := c.DNSZones.ListZones(context.Background())
require.NoError(t, err)
assert.Equal(t, *zone, zones[0])
getZone, err := c.DNSZones.GetZone(context.Background(), zone.Id)
require.NoError(t, err)
assert.Equal(t, *zone, *getZone)
zoneReq.Name = "updated-zone"
updatedZone, err := c.DNSZones.UpdateZone(context.Background(), zone.Id, zoneReq)
require.NoError(t, err)
assert.Equal(t, "updated-zone", updatedZone.Name)
record, err := c.DNSZones.CreateRecord(context.Background(), zone.Id, recordReq)
require.NoError(t, err)
assert.Equal(t, "api.test.example.com", record.Name)
assert.Equal(t, "192.168.1.100", record.Content)
records, err := c.DNSZones.ListRecords(context.Background(), zone.Id)
require.NoError(t, err)
assert.Equal(t, *record, records[0])
getRecord, err := c.DNSZones.GetRecord(context.Background(), zone.Id, record.Id)
require.NoError(t, err)
assert.Equal(t, *record, *getRecord)
recordReq.Name = "www.test.example.com"
updatedRecord, err := c.DNSZones.UpdateRecord(context.Background(), zone.Id, record.Id, recordReq)
require.NoError(t, err)
assert.Equal(t, "www.test.example.com", updatedRecord.Name)
err = c.DNSZones.DeleteRecord(context.Background(), zone.Id, record.Id)
require.NoError(t, err)
records, err = c.DNSZones.ListRecords(context.Background(), zone.Id)
require.NoError(t, err)
assert.Len(t, records, 0)
err = c.DNSZones.DeleteZone(context.Background(), zone.Id)
require.NoError(t, err)
zones, err = c.DNSZones.ListZones(context.Background())
require.NoError(t, err)
assert.Len(t, zones, 0)
})
}

View File

@@ -25,6 +25,8 @@ tags:
description: Interact with and view information about routes.
- name: DNS
description: Interact with and view information about DNS configuration.
- name: DNS Zones
description: Interact with and view information about custom DNS zones.
- name: Events
description: View information about the account and network events.
- name: Accounts
@@ -1779,6 +1781,100 @@ components:
example: ch8i4ug6lnn4g9hqv7m0
required:
- disabled_management_groups
ZoneRequest:
type: object
properties:
name:
description: Zone name identifier
type: string
maxLength: 255
minLength: 1
example: Office Zone
domain:
description: Zone domain (FQDN)
type: string
example: example.com
enabled:
description: Zone status
type: boolean
default: true
enable_search_domain:
description: Enable this zone as a search domain
type: boolean
example: false
distribution_groups:
description: Group IDs that defines groups of peers that will resolve this zone
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
required:
- name
- domain
- enable_search_domain
- distribution_groups
Zone:
allOf:
- type: object
properties:
id:
description: Zone ID
type: string
example: ch8i4ug6lnn4g9hqv7m0
records:
description: DNS records associated with this zone
type: array
items:
$ref: '#/components/schemas/DNSRecord'
required:
- id
- enabled
- records
- $ref: '#/components/schemas/ZoneRequest'
DNSRecordType:
type: string
description: DNS record type
enum:
- A
- AAAA
- CNAME
example: A
DNSRecordRequest:
type: object
properties:
name:
description: FQDN for the DNS record. Must be a subdomain within or match the zone's domain.
type: string
example: www.example.com
type:
$ref: '#/components/schemas/DNSRecordType'
content:
description: DNS record content (IP address for A/AAAA, domain for CNAME)
type: string
maxLength: 255
minLength: 1
example: 192.168.1.1
ttl:
description: Time to live in seconds
type: integer
minimum: 0
example: 300
required:
- name
- type
- content
- ttl
DNSRecord:
allOf:
- type: object
properties:
id:
description: DNS record ID
type: string
example: ch8i4ug6lnn4g9hqv7m0
required:
- id
- $ref: '#/components/schemas/DNSRecordRequest'
Event:
type: object
properties:
@@ -4733,6 +4829,347 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/zones:
get:
summary: List all DNS Zones
description: Returns a list of all custom DNS zones
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of DNS Zones
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/Zone'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a DNS Zone
description: Creates a new custom DNS zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: A DNS zone object
content:
'application/json':
schema:
$ref: '#/components/schemas/ZoneRequest'
responses:
'200':
description: A JSON Object of the created DNS Zone
content:
application/json:
schema:
$ref: '#/components/schemas/Zone'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/zones/{zoneId}:
get:
summary: Retrieve a DNS Zone
description: Returns information about a specific DNS zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
responses:
'200':
description: A JSON Object of a DNS Zone
content:
application/json:
schema:
$ref: '#/components/schemas/Zone'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update a DNS Zone
description: Updates a custom DNS zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
requestBody:
description: A DNS zone object
content:
'application/json':
schema:
$ref: '#/components/schemas/ZoneRequest'
responses:
'200':
description: A JSON Object of the updated DNS Zone
content:
application/json:
schema:
$ref: '#/components/schemas/Zone'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a DNS Zone
description: Deletes a custom DNS zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
responses:
'200':
description: Zone deletion successful
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/zones/{zoneId}/records:
get:
summary: List all DNS Records
description: Returns a list of all DNS records in a zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
responses:
'200':
description: A JSON Array of DNS Records
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/DNSRecord'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a DNS Record
description: Creates a new DNS record in a zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
requestBody:
description: A DNS record object
content:
'application/json':
schema:
$ref: '#/components/schemas/DNSRecordRequest'
responses:
'200':
description: A JSON Object of the created DNS Record
content:
application/json:
schema:
$ref: '#/components/schemas/DNSRecord'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/zones/{zoneId}/records/{recordId}:
get:
summary: Retrieve a DNS Record
description: Returns information about a specific DNS record
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
- in: path
name: recordId
required: true
schema:
type: string
description: The unique identifier of a DNS record
example: chacbco6lnnbn6cg5s92
responses:
'200':
description: A JSON Object of a DNS Record
content:
application/json:
schema:
$ref: '#/components/schemas/DNSRecord'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update a DNS Record
description: Updates a DNS record in a zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
- in: path
name: recordId
required: true
schema:
type: string
description: The unique identifier of a DNS record
example: chacbco6lnnbn6cg5s92
requestBody:
description: A DNS record object
content:
'application/json':
schema:
$ref: '#/components/schemas/DNSRecordRequest'
responses:
'200':
description: A JSON Object of the updated DNS Record
content:
application/json:
schema:
$ref: '#/components/schemas/DNSRecord'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a DNS Record
description: Deletes a DNS record from a zone
tags: [ DNS Zones ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: zoneId
required: true
schema:
type: string
description: The unique identifier of a zone
example: chacbco6lnnbn6cg5s91
- in: path
name: recordId
required: true
schema:
type: string
description: The unique identifier of a DNS record
example: chacbco6lnnbn6cg5s92
responses:
'200':
description: Record deletion successful
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/events/audit:
get:
summary: List all Audit Events

View File

@@ -12,6 +12,13 @@ const (
TokenAuthScopes = "TokenAuth.Scopes"
)
// Defines values for DNSRecordType.
const (
DNSRecordTypeA DNSRecordType = "A"
DNSRecordTypeAAAA DNSRecordType = "AAAA"
DNSRecordTypeCNAME DNSRecordType = "CNAME"
)
// Defines values for EventActivityCode.
const (
EventActivityCodeAccountCreate EventActivityCode = "account.create"
@@ -427,6 +434,42 @@ type CreateSetupKeyRequest struct {
UsageLimit int `json:"usage_limit"`
}
// DNSRecord defines model for DNSRecord.
type DNSRecord struct {
// Content DNS record content (IP address for A/AAAA, domain for CNAME)
Content string `json:"content"`
// Id DNS record ID
Id string `json:"id"`
// Name FQDN for the DNS record. Must be a subdomain within or match the zone's domain.
Name string `json:"name"`
// Ttl Time to live in seconds
Ttl int `json:"ttl"`
// Type DNS record type
Type DNSRecordType `json:"type"`
}
// DNSRecordRequest defines model for DNSRecordRequest.
type DNSRecordRequest struct {
// Content DNS record content (IP address for A/AAAA, domain for CNAME)
Content string `json:"content"`
// Name FQDN for the DNS record. Must be a subdomain within or match the zone's domain.
Name string `json:"name"`
// Ttl Time to live in seconds
Ttl int `json:"ttl"`
// Type DNS record type
Type DNSRecordType `json:"type"`
}
// DNSRecordType DNS record type
type DNSRecordType string
// DNSSettings defines model for DNSSettings.
type DNSSettings struct {
// DisabledManagementGroups Groups whose DNS management is disabled
@@ -1999,6 +2042,48 @@ type UserRequest struct {
Role string `json:"role"`
}
// Zone defines model for Zone.
type Zone struct {
// DistributionGroups Group IDs that defines groups of peers that will resolve this zone
DistributionGroups []string `json:"distribution_groups"`
// Domain Zone domain (FQDN)
Domain string `json:"domain"`
// EnableSearchDomain Enable this zone as a search domain
EnableSearchDomain bool `json:"enable_search_domain"`
// Enabled Zone status
Enabled bool `json:"enabled"`
// Id Zone ID
Id string `json:"id"`
// Name Zone name identifier
Name string `json:"name"`
// Records DNS records associated with this zone
Records []DNSRecord `json:"records"`
}
// ZoneRequest defines model for ZoneRequest.
type ZoneRequest struct {
// DistributionGroups Group IDs that defines groups of peers that will resolve this zone
DistributionGroups []string `json:"distribution_groups"`
// Domain Zone domain (FQDN)
Domain string `json:"domain"`
// EnableSearchDomain Enable this zone as a search domain
EnableSearchDomain bool `json:"enable_search_domain"`
// Enabled Zone status
Enabled *bool `json:"enabled,omitempty"`
// Name Zone name identifier
Name string `json:"name"`
}
// GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic.
type GetApiEventsNetworkTrafficParams struct {
// Page Page number
@@ -2083,6 +2168,18 @@ type PutApiDnsNameserversNsgroupIdJSONRequestBody = NameserverGroupRequest
// PutApiDnsSettingsJSONRequestBody defines body for PutApiDnsSettings for application/json ContentType.
type PutApiDnsSettingsJSONRequestBody = DNSSettings
// PostApiDnsZonesJSONRequestBody defines body for PostApiDnsZones for application/json ContentType.
type PostApiDnsZonesJSONRequestBody = ZoneRequest
// PutApiDnsZonesZoneIdJSONRequestBody defines body for PutApiDnsZonesZoneId for application/json ContentType.
type PutApiDnsZonesZoneIdJSONRequestBody = ZoneRequest
// PostApiDnsZonesZoneIdRecordsJSONRequestBody defines body for PostApiDnsZonesZoneIdRecords for application/json ContentType.
type PostApiDnsZonesZoneIdRecordsJSONRequestBody = DNSRecordRequest
// PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody defines body for PutApiDnsZonesZoneIdRecordsRecordId for application/json ContentType.
type PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody = DNSRecordRequest
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
type PostApiGroupsJSONRequestBody = GroupRequest

View File

@@ -252,3 +252,13 @@ func NewOperationNotFoundError(operation operations.Operation) error {
func NewRouteNotFoundError(routeID string) error {
return Errorf(NotFound, "route: %s not found", routeID)
}
// NewZoneNotFoundError creates a new Error with NotFound type for a missing dns zone.
func NewZoneNotFoundError(zoneID string) error {
return Errorf(NotFound, "zone: %s not found", zoneID)
}
// NewDNSRecordNotFoundError creates a new Error with NotFound type for a missing dns record.
func NewDNSRecordNotFoundError(recordID string) error {
return Errorf(NotFound, "dns record: %s not found", recordID)
}