mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:24:18 -04:00
[management] Add custom dns zones (#4849)
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
}
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -14,6 +15,7 @@ type Repository interface {
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
|
||||
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
|
||||
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
13
management/internals/modules/zones/interface.go
Normal file
13
management/internals/modules/zones/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package zones
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
|
||||
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
|
||||
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
|
||||
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
|
||||
}
|
||||
161
management/internals/modules/zones/manager/api.go
Normal file
161
management/internals/modules/zones/manager/api.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager zones.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
apiZones := make([]*api.Zone, 0, len(allZones))
|
||||
for _, zone := range allZones {
|
||||
apiZones = append(apiZones, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiZones)
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsZonesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
zone.ID = zoneID
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
229
management/internals/modules/zones/manager/manager.go
Normal file
229
management/internals/modules/zones/manager/manager.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
dnsDomain string
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||
}
|
||||
}
|
||||
if existingZone != nil {
|
||||
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
|
||||
}
|
||||
|
||||
for _, groupID := range zone.DistributionGroups {
|
||||
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.CreateZone(ctx, zone); err != nil {
|
||||
return fmt.Errorf("failed to create zone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
if zone.Domain != updatedZone.Domain {
|
||||
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
|
||||
}
|
||||
|
||||
zone.Name = updatedZone.Name
|
||||
zone.Enabled = updatedZone.Enabled
|
||||
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
|
||||
zone.DistributionGroups = updatedZone.DistributionGroups
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range zone.DistributionGroups {
|
||||
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.UpdateZone(ctx, zone); err != nil {
|
||||
return fmt.Errorf("failed to update zone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get records: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete zone dns records: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteZone(ctx, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete zone: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
|
||||
})
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range eventsToStore {
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
|
||||
if m.dnsDomain != "" && m.dnsDomain == domain {
|
||||
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.DNSDomain != "" && settings.DNSDomain == domain {
|
||||
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
553
management/internals/modules/zones/manager/manager_test.go
Normal file
553
management/internals/modules/zones/manager/manager_test.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testZoneID = "test-zone-id"
|
||||
testGroupID = "test-group-id"
|
||||
testDNSDomain = "netbird.selfhosted"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testStore.SaveAccount(ctx, &types.Account{
|
||||
Id: testAccountID,
|
||||
Groups: map[string]*types.Group{
|
||||
testGroupID: {
|
||||
ID: testGroupID,
|
||||
Name: "Test Group",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
dnsDomain: testDNSDomain,
|
||||
}
|
||||
|
||||
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllZones(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone1)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
|
||||
err = testStore.CreateZone(ctx, zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, zone1.ID, result[0].ID)
|
||||
assert.Equal(t, zone2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, zone.ID, result.ID)
|
||||
assert.Equal(t, zone.Name, result.Name)
|
||||
assert.Equal(t, zone.Domain, result.Domain)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
assert.Equal(t, testAccountID, result.AccountID)
|
||||
assert.Equal(t, inputZone.Name, result.Name)
|
||||
assert.Equal(t, inputZone.Domain, result.Domain)
|
||||
assert.Equal(t, inputZone.Enabled, result.Enabled)
|
||||
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("invalid group", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "new.example.com",
|
||||
DistributionGroups: []string{"invalid-group"},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate domain", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "New Zone",
|
||||
Domain: "duplicate.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, s.Type())
|
||||
})
|
||||
|
||||
t.Run("peer DNS domain conflict", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
account, err := testStore.GetAccount(ctx, testAccountID)
|
||||
require.NoError(t, err)
|
||||
account.Settings.DNSDomain = "peers.example.com"
|
||||
err = testStore.SaveAccount(ctx, account)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "Test Zone",
|
||||
Domain: "peers.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("default DNS domain conflict", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputZone := &zones.Zone{
|
||||
Name: "Test Zone",
|
||||
Domain: testDNSDomain,
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_UpdateZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: existingZone.ID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, existingZone.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneUpdated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, updatedZone.Name, result.Name)
|
||||
assert.Equal(t, updatedZone.Enabled, result.Enabled)
|
||||
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
})
|
||||
|
||||
t.Run("domain change not allowed", func(t *testing.T) {
|
||||
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, existingZone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: existingZone.ID,
|
||||
Name: "Test Zone",
|
||||
Domain: "different.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{testGroupID},
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "zone domain cannot be updated")
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.InvalidArgument, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: testZoneID,
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedZone := &zones.Zone{
|
||||
ID: "non-existent-zone",
|
||||
Name: "Updated Name",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_DeleteZone(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success with records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCallCount := 0
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCallCount++
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
}
|
||||
|
||||
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, storeEventCallCount)
|
||||
|
||||
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
|
||||
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, zoneRecords)
|
||||
})
|
||||
|
||||
t.Run("success without records", func(t *testing.T) {
|
||||
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err := testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, zone.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSZoneDeleted, activityID)
|
||||
}
|
||||
|
||||
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
|
||||
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("zone not found", func(t *testing.T) {
|
||||
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
13
management/internals/modules/zones/records/interface.go
Normal file
13
management/internals/modules/zones/records/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
|
||||
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
|
||||
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
|
||||
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
|
||||
}
|
||||
191
management/internals/modules/zones/records/manager/api.go
Normal file
191
management/internals/modules/zones/records/manager/api.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager records.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
|
||||
for _, record := range allRecords {
|
||||
apiRecords = append(apiRecords, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiRecords)
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
record.ID = recordID
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
recordID := mux.Vars(r)["recordId"]
|
||||
if recordID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
236
management/internals/modules/zones/records/manager/manager.go
Normal file
236
management/internals/modules/zones/records/manager/manager.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
|
||||
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
err = validateRecordConflicts(ctx, transaction, zone, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
|
||||
return fmt.Errorf("failed to create dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
var record *records.Record
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get record: %w", err)
|
||||
}
|
||||
|
||||
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
|
||||
|
||||
record.Name = updatedRecord.Name
|
||||
record.Type = updatedRecord.Type
|
||||
record.Content = updatedRecord.Content
|
||||
record.TTL = updatedRecord.TTL
|
||||
|
||||
if hasChanges {
|
||||
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
|
||||
return fmt.Errorf("failed to update dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var record *records.Record
|
||||
var zone *zones.Zone
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get zone: %w", err)
|
||||
}
|
||||
|
||||
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete dns record: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRecordConflicts checks for duplicate records and CNAME conflicts
|
||||
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
|
||||
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
|
||||
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
|
||||
}
|
||||
|
||||
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing records: %w", err)
|
||||
}
|
||||
|
||||
for _, existing := range existingRecords {
|
||||
if existing.ID == record.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if existing.Type == record.Type && existing.Content == record.Content {
|
||||
return status.Errorf(status.AlreadyExists, "identical record already exists")
|
||||
}
|
||||
|
||||
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
|
||||
return status.Errorf(status.InvalidArgument,
|
||||
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,573 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testRecordID = "test-record-id"
|
||||
testGroupID = "test-group-id"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testStore.SaveAccount(ctx, &types.Account{
|
||||
Id: testAccountID,
|
||||
Groups: map[string]*types.Group{
|
||||
testGroupID: {
|
||||
ID: testGroupID,
|
||||
Name: "Test Group",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
|
||||
err = testStore.CreateZone(ctx, zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockAccountManager := &mock_server.MockAccountManager{}
|
||||
mockPermissionsManager := permissions.NewMockManager(ctrl)
|
||||
|
||||
manager := &managerImpl{
|
||||
store: testStore,
|
||||
accountManager: mockAccountManager,
|
||||
permissionsManager: mockPermissionsManager,
|
||||
}
|
||||
|
||||
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetAllRecords(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, record1.ID, result[0].ID)
|
||||
assert.Equal(t, record2.ID, result[1].ID)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("permission validation error", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, status.Errorf(status.Internal, "permission check failed"))
|
||||
|
||||
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_GetRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.ID, result.ID)
|
||||
assert.Equal(t, record.Name, result.Name)
|
||||
assert.Equal(t, record.Type, result.Type)
|
||||
assert.Equal(t, record.Content, result.Content)
|
||||
assert.Equal(t, record.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_CreateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success - A record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
assert.Equal(t, testAccountID, result.AccountID)
|
||||
assert.Equal(t, zone.ID, result.ZoneID)
|
||||
assert.Equal(t, inputRecord.Name, result.Name)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
assert.Equal(t, inputRecord.TTL, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("success - AAAA record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "ipv6.example.com",
|
||||
Type: records.RecordTypeAAAA,
|
||||
Content: "2001:db8::1",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("success - CNAME record", func(t *testing.T) {
|
||||
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "www.example.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "example.com",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordCreated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, inputRecord.Type, result.Type)
|
||||
assert.Equal(t, inputRecord.Content, result.Content)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record name not in zone", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.different.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "does not belong to zone")
|
||||
})
|
||||
|
||||
t.Run("duplicate record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "identical record already exists")
|
||||
})
|
||||
|
||||
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRecord := &records.Record{
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "example.com",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_UpdateRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: existingRecord.ID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100", // Changed IP
|
||||
TTL: 600, // Changed TTL
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, existingRecord.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordUpdated, activityID)
|
||||
}
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, updatedRecord.Content, result.Content)
|
||||
assert.Equal(t, updatedRecord.TTL, result.TTL)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
})
|
||||
|
||||
t.Run("update only TTL - no validation", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, existingRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: existingRecord.ID,
|
||||
Name: existingRecord.Name,
|
||||
Type: existingRecord.Type,
|
||||
Content: existingRecord.Content,
|
||||
TTL: 600, // Only TTL changed
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
// Event should be stored
|
||||
}
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, 600, result.TTL)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: testRecordID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(false, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: "non-existent-record",
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.100",
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("update creates duplicate", func(t *testing.T) {
|
||||
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
|
||||
err = testStore.CreateDNSRecord(ctx, record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord := &records.Record{
|
||||
ID: record2.ID,
|
||||
Name: "api.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
|
||||
Return(true, nil)
|
||||
|
||||
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "identical record already exists")
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerImpl_DeleteRecord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err := testStore.CreateDNSRecord(ctx, record)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
storeEventCalled := false
|
||||
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
storeEventCalled = true
|
||||
assert.Equal(t, testUserID, initiatorID)
|
||||
assert.Equal(t, record.ID, targetID)
|
||||
assert.Equal(t, testAccountID, accountID)
|
||||
assert.Equal(t, activity.DNSRecordDeleted, activityID)
|
||||
}
|
||||
|
||||
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storeEventCalled, "StoreEvent should have been called")
|
||||
|
||||
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("permission denied", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(false, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, status.PermissionDenied, s.Type())
|
||||
})
|
||||
|
||||
t.Run("record not found", func(t *testing.T) {
|
||||
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
|
||||
defer cleanup()
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPermissionsManager.EXPECT().
|
||||
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
|
||||
Return(true, nil)
|
||||
|
||||
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
129
management/internals/modules/zones/records/record.go
Normal file
129
management/internals/modules/zones/records/record.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package records
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type RecordType string
|
||||
|
||||
const (
|
||||
RecordTypeA RecordType = "A"
|
||||
RecordTypeAAAA RecordType = "AAAA"
|
||||
RecordTypeCNAME RecordType = "CNAME"
|
||||
)
|
||||
|
||||
type Record struct {
|
||||
AccountID string `gorm:"index"`
|
||||
ZoneID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
Name string
|
||||
Type RecordType
|
||||
Content string
|
||||
TTL int
|
||||
}
|
||||
|
||||
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
|
||||
return &Record{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
ZoneID: zoneID,
|
||||
Name: name,
|
||||
Type: recordType,
|
||||
Content: content,
|
||||
TTL: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Record) ToAPIResponse() *api.DNSRecord {
|
||||
recordType := api.DNSRecordType(r.Type)
|
||||
return &api.DNSRecord{
|
||||
Id: r.ID,
|
||||
Name: r.Name,
|
||||
Type: recordType,
|
||||
Content: r.Content,
|
||||
Ttl: r.TTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
|
||||
r.Name = req.Name
|
||||
r.Type = RecordType(req.Type)
|
||||
r.Content = req.Content
|
||||
r.TTL = req.Ttl
|
||||
}
|
||||
|
||||
func (r *Record) Validate() error {
|
||||
if r.Name == "" {
|
||||
return errors.New("record name is required")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(r.Name) {
|
||||
return errors.New("invalid record name format")
|
||||
}
|
||||
|
||||
if r.Type == "" {
|
||||
return errors.New("record type is required")
|
||||
}
|
||||
|
||||
switch r.Type {
|
||||
case RecordTypeA:
|
||||
if err := validateIPv4(r.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
case RecordTypeAAAA:
|
||||
if err := validateIPv6(r.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
case RecordTypeCNAME:
|
||||
if !util.IsValidDomain(r.Content) {
|
||||
return errors.New("invalid CNAME record format")
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid record type, must be A, AAAA, or CNAME")
|
||||
}
|
||||
|
||||
if r.TTL < 0 {
|
||||
return errors.New("TTL cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
|
||||
return map[string]any{
|
||||
"name": r.Name,
|
||||
"type": string(r.Type),
|
||||
"content": r.Content,
|
||||
"ttl": r.TTL,
|
||||
"zone_id": zoneID,
|
||||
"zone_name": zoneName,
|
||||
}
|
||||
}
|
||||
|
||||
func validateIPv4(content string) error {
|
||||
if content == "" {
|
||||
return errors.New("A record is required") //nolint:staticcheck
|
||||
}
|
||||
ip := net.ParseIP(content)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateIPv6(content string) error {
|
||||
if content == "" {
|
||||
return errors.New("AAAA record is required")
|
||||
}
|
||||
ip := net.ParseIP(content)
|
||||
if ip == nil || ip.To4() != nil {
|
||||
return errors.New("AAAA record must be a valid IPv6 address")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
89
management/internals/modules/zones/zone.go
Normal file
89
management/internals/modules/zones/zone.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package zones
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type Zone struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string
|
||||
Enabled bool
|
||||
EnableSearchDomain bool
|
||||
DistributionGroups []string `gorm:"serializer:json"`
|
||||
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
|
||||
}
|
||||
|
||||
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
|
||||
return &Zone{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Domain: domain,
|
||||
Enabled: enabled,
|
||||
EnableSearchDomain: enableSearchDomain,
|
||||
DistributionGroups: distributionGroups,
|
||||
}
|
||||
}
|
||||
|
||||
func (z *Zone) ToAPIResponse() *api.Zone {
|
||||
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
|
||||
for _, record := range z.Records {
|
||||
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
|
||||
apiRecords = append(apiRecords, *apiRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.Zone{
|
||||
DistributionGroups: z.DistributionGroups,
|
||||
Domain: z.Domain,
|
||||
EnableSearchDomain: z.EnableSearchDomain,
|
||||
Enabled: z.Enabled,
|
||||
Id: z.ID,
|
||||
Name: z.Name,
|
||||
Records: apiRecords,
|
||||
}
|
||||
}
|
||||
|
||||
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
|
||||
z.Name = req.Name
|
||||
z.Domain = req.Domain
|
||||
z.EnableSearchDomain = req.EnableSearchDomain
|
||||
z.DistributionGroups = req.DistributionGroups
|
||||
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
z.Enabled = enabled
|
||||
}
|
||||
|
||||
func (z *Zone) Validate() error {
|
||||
if z.Name == "" {
|
||||
return errors.New("zone name is required")
|
||||
}
|
||||
if len(z.Name) > 255 {
|
||||
return errors.New("zone name exceeds maximum length of 255 characters")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(z.Domain) {
|
||||
return errors.New("invalid zone domain format")
|
||||
}
|
||||
|
||||
if len(z.DistributionGroups) == 0 {
|
||||
return errors.New("at least one distribution group is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *Zone) EventMeta() map[string]any {
|
||||
return map[string]any{"name": z.Name, "domain": z.Domain}
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ZonesManager() zones.Manager {
|
||||
return Create(s, func() zones.Manager {
|
||||
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RecordsManager() records.Manager {
|
||||
return Create(s, func() records.Manager {
|
||||
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -374,9 +374,10 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
|
||||
@@ -295,7 +295,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -388,7 +388,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return newSettings, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -402,6 +402,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new
|
||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||
}
|
||||
|
||||
if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" {
|
||||
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing zone: %w", err)
|
||||
}
|
||||
}
|
||||
if existingZone != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain)
|
||||
}
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -397,7 +398,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -1676,7 +1677,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
@@ -1686,7 +1687,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
@@ -2095,6 +2096,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
ctx := context.Background()
|
||||
err = manager.Store.CreateZone(ctx, &zones.Zone{
|
||||
ID: "test-zone-id",
|
||||
AccountID: accountID,
|
||||
Name: "Test Zone",
|
||||
Domain: "custom.example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{},
|
||||
})
|
||||
require.NoError(t, err, "unable to create custom DNS zone")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
|
||||
DNSDomain: "custom.example.com",
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone")
|
||||
assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone")
|
||||
}
|
||||
|
||||
func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
|
||||
@@ -187,6 +187,14 @@ const (
|
||||
IdentityProviderUpdated Activity = 94
|
||||
IdentityProviderDeleted Activity = 95
|
||||
|
||||
DNSZoneCreated Activity = 96
|
||||
DNSZoneUpdated Activity = 97
|
||||
DNSZoneDeleted Activity = 98
|
||||
|
||||
DNSRecordCreated Activity = 99
|
||||
DNSRecordUpdated Activity = 100
|
||||
DNSRecordDeleted Activity = 101
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -303,6 +311,14 @@ var activityMap = map[Activity]Code{
|
||||
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
|
||||
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
|
||||
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
|
||||
|
||||
DNSZoneCreated: {"DNS zone created", "dns.zone.create"},
|
||||
DNSZoneUpdated: {"DNS zone updated", "dns.zone.update"},
|
||||
DNSZoneDeleted: {"DNS zone deleted", "dns.zone.delete"},
|
||||
|
||||
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
|
||||
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
|
||||
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -15,7 +15,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
@@ -56,7 +59,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -138,6 +141,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -298,8 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
@@ -93,8 +95,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,8 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -123,6 +125,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&zones.Zone{}, &records.Record{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -4179,3 +4182,184 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateZone(ctx context.Context, zone *zones.Zone) error {
|
||||
result := s.db.Create(zone)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create zone to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to create zone to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateZone(ctx context.Context, zone *zones.Zone) error {
|
||||
result := s.db.Select("*").Save(zone)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update zone to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update zone to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteZone(ctx context.Context, accountID, zoneID string) error {
|
||||
result := s.db.Delete(&zones.Zone{}, accountAndIDQueryCondition, accountID, zoneID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete zone from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete zone from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewZoneNotFoundError(zoneID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var zone *zones.Zone
|
||||
result := tx.Preload("Records").Take(&zone, accountAndIDQueryCondition, accountID, zoneID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewZoneNotFoundError(zoneID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get zone from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone from store")
|
||||
}
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) {
|
||||
var zone *zones.Zone
|
||||
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&zone)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewZoneNotFoundError(domain)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get zone by domain from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone by domain from store")
|
||||
}
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var zones []*zones.Zone
|
||||
result := tx.Preload("Records").Find(&zones, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get zones from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zones from store")
|
||||
}
|
||||
|
||||
return zones, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateDNSRecord(ctx context.Context, record *records.Record) error {
|
||||
result := s.db.Create(record)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create dns record to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to create dns record to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateDNSRecord(ctx context.Context, record *records.Record) error {
|
||||
result := s.db.Select("*").Save(record)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update dns record to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update dns record to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error {
|
||||
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete dns record from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete dns record from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewDNSRecordNotFoundError(recordID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var record *records.Record
|
||||
result := tx.Where("account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID).Take(&record)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewDNSRecordNotFoundError(recordID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get dns record from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns record from store")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var recordsList []*records.Record
|
||||
result := tx.Where("account_id = ? AND zone_id = ?", accountID, zoneID).Find(&recordsList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get zone dns records from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone dns records from store")
|
||||
}
|
||||
|
||||
return recordsList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var recordsList []*records.Record
|
||||
result := tx.Where("account_id = ? AND zone_id = ? AND name = ?", accountID, zoneID, name).Find(&recordsList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get zone dns records by name from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get zone dns records by name from store")
|
||||
}
|
||||
|
||||
return recordsList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error {
|
||||
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ?", accountID, zoneID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete zone dns records from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete zone dns records from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -4025,3 +4027,476 @@ func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
|
||||
}
|
||||
|
||||
func TestSqlStore_CreateZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, zone.ID, savedZone.ID)
|
||||
assert.Equal(t, zone.Name, savedZone.Name)
|
||||
assert.Equal(t, zone.Domain, savedZone.Domain)
|
||||
assert.Equal(t, zone.Enabled, savedZone.Enabled)
|
||||
assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain)
|
||||
assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
zoneID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing zone",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing zone",
|
||||
accountID: accountID,
|
||||
zoneID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty zone ID",
|
||||
accountID: accountID,
|
||||
zoneID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, savedZone)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, tt.zoneID, savedZone.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountZones(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone1)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"})
|
||||
err = store.CreateZone(context.Background(), zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, allZones)
|
||||
assert.GreaterOrEqual(t, len(allZones), 2)
|
||||
|
||||
zoneIDs := make(map[string]bool)
|
||||
for _, z := range allZones {
|
||||
zoneIDs[z.ID] = true
|
||||
}
|
||||
assert.True(t, zoneIDs[zone1.ID])
|
||||
assert.True(t, zoneIDs[zone2.ID])
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneByDomain(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
domain string
|
||||
expectError bool
|
||||
errorType status.Type
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing zone by domain",
|
||||
accountID: accountID,
|
||||
domain: "example.com",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing zone domain",
|
||||
accountID: accountID,
|
||||
domain: "non-existing.com",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty domain",
|
||||
accountID: accountID,
|
||||
domain: "",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
{
|
||||
name: "retrieve with different account ID",
|
||||
accountID: otherAccountID,
|
||||
domain: "example.com",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.errorType, sErr.Type())
|
||||
require.Nil(t, savedZone)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, tt.domain, savedZone.Domain)
|
||||
assert.Equal(t, zone.ID, savedZone.ID)
|
||||
assert.Equal(t, zone.Name, savedZone.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone.Name = "Updated Zone"
|
||||
zone.Domain = "updated.com"
|
||||
zone.Enabled = false
|
||||
zone.EnableSearchDomain = true
|
||||
zone.DistributionGroups = []string{"group2", "group3"}
|
||||
|
||||
err = store.UpdateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedZone)
|
||||
assert.Equal(t, "Updated Zone", updatedZone.Name)
|
||||
assert.Equal(t, "updated.com", updatedZone.Domain)
|
||||
assert.False(t, updatedZone.Enabled)
|
||||
assert.True(t, updatedZone.EnableSearchDomain)
|
||||
assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.DeleteZone(context.Background(), accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, deletedZone)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
}
|
||||
|
||||
func TestSqlStore_CreateDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedRecord)
|
||||
assert.Equal(t, record.ID, savedRecord.ID)
|
||||
assert.Equal(t, record.Name, savedRecord.Name)
|
||||
assert.Equal(t, record.Type, savedRecord.Type)
|
||||
assert.Equal(t, record.Content, savedRecord.Content)
|
||||
assert.Equal(t, record.TTL, savedRecord.TTL)
|
||||
assert.Equal(t, zone.ID, savedRecord.ZoneID)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetDNSRecordByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
zoneID string
|
||||
recordID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing record",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: record.ID,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing record",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty record ID",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, savedRecord)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedRecord)
|
||||
assert.Equal(t, tt.recordID, savedRecord.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneDNSRecords(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordA)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordAAAA)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordCNAME)
|
||||
require.NoError(t, err)
|
||||
|
||||
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, allRecords)
|
||||
assert.Equal(t, 3, len(allRecords))
|
||||
|
||||
recordIDs := make(map[string]bool)
|
||||
for _, r := range allRecords {
|
||||
recordIDs[r.ID] = true
|
||||
}
|
||||
assert.True(t, recordIDs[recordA.ID])
|
||||
assert.True(t, recordIDs[recordAAAA.ID])
|
||||
assert.True(t, recordIDs[recordCNAME.ID])
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||
err = store.CreateDNSRecord(context.Background(), record3)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, recordsByName)
|
||||
assert.Equal(t, 2, len(recordsByName))
|
||||
|
||||
for _, r := range recordsByName {
|
||||
assert.Equal(t, "www.example.com", r.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
record.Name = "api.example.com"
|
||||
record.Content = "192.168.1.100"
|
||||
record.TTL = 600
|
||||
|
||||
err = store.UpdateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedRecord)
|
||||
assert.Equal(t, "api.example.com", updatedRecord.Name)
|
||||
assert.Equal(t, "192.168.1.100", updatedRecord.Content)
|
||||
assert.Equal(t, 600, updatedRecord.TTL)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, deletedRecord)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||
err = store.CreateDNSRecord(context.Background(), record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(allRecords))
|
||||
|
||||
err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(remainingRecords))
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -209,6 +211,21 @@ type Store interface {
|
||||
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
||||
SetFieldEncrypt(enc *crypt.FieldEncrypt)
|
||||
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
||||
|
||||
CreateZone(ctx context.Context, zone *zones.Zone) error
|
||||
UpdateZone(ctx context.Context, zone *zones.Zone) error
|
||||
DeleteZone(ctx context.Context, accountID, zoneID string) error
|
||||
GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error)
|
||||
GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error)
|
||||
GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error)
|
||||
|
||||
CreateDNSRecord(ctx context.Context, record *records.Record) error
|
||||
UpdateDNSRecord(ctx context.Context, record *records.Record) error
|
||||
DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error
|
||||
GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error)
|
||||
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
|
||||
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
|
||||
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -150,17 +152,16 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
groupListMap := a.GetPeerGroups(peerID)
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
@@ -274,6 +275,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
@@ -294,6 +296,8 @@ func (a *Account) GetPeerNetworkMap(
|
||||
}
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
|
||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
@@ -307,7 +311,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
@@ -323,6 +327,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
@@ -330,6 +335,10 @@ func (a *Account) GetPeerNetworkMap(
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
|
||||
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
zones = append(zones, filteredAccountZones...)
|
||||
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
@@ -1881,3 +1890,66 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
|
||||
// filterPeerAppliedZones filters account zones based on the peer's group membership
|
||||
func filterPeerAppliedZones(ctx context.Context, accountZones []*zones.Zone, peerGroups LookupMap) []nbdns.CustomZone {
|
||||
var customZones []nbdns.CustomZone
|
||||
|
||||
if len(peerGroups) == 0 {
|
||||
return customZones
|
||||
}
|
||||
|
||||
for _, zone := range accountZones {
|
||||
if !zone.Enabled || len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hasAccess := false
|
||||
for _, distGroupID := range zone.DistributionGroups {
|
||||
if _, found := peerGroups[distGroupID]; found {
|
||||
hasAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAccess {
|
||||
continue
|
||||
}
|
||||
|
||||
simpleRecords := make([]nbdns.SimpleRecord, 0, len(zone.Records))
|
||||
for _, record := range zone.Records {
|
||||
var recordType int
|
||||
rData := record.Content
|
||||
|
||||
switch record.Type {
|
||||
case records.RecordTypeA:
|
||||
recordType = int(dns.TypeA)
|
||||
case records.RecordTypeAAAA:
|
||||
recordType = int(dns.TypeAAAA)
|
||||
case records.RecordTypeCNAME:
|
||||
recordType = int(dns.TypeCNAME)
|
||||
rData = dns.Fqdn(record.Content)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("unknown DNS record type %s for record %s", record.Type, record.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
simpleRecords = append(simpleRecords, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(record.Name),
|
||||
Type: recordType,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: record.TTL,
|
||||
RData: rData,
|
||||
})
|
||||
}
|
||||
|
||||
customZones = append(customZones, nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(zone.Domain),
|
||||
Records: simpleRecords,
|
||||
SearchDomainDisabled: !zone.EnableSearchDomain,
|
||||
NonAuthoritative: true,
|
||||
})
|
||||
}
|
||||
|
||||
return customZones
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -1425,3 +1427,515 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_filterPeerAppliedZones(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountZones []*zones.Zone
|
||||
peerGroups LookupMap
|
||||
expected []nbdns.CustomZone
|
||||
}{
|
||||
{
|
||||
name: "empty peer groups returns empty custom zones",
|
||||
accountZones: []*zones.Zone{},
|
||||
peerGroups: LookupMap{},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "peer has access to zone with A record",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.example.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peer has access to zone with search domain enabled",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "internal.local",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "api.internal.local",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "10.0.0.1",
|
||||
TTL: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "internal.local.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "api.internal.local.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 600,
|
||||
RData: "10.0.0.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peer has no access to zone",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "private.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group2"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "secret.private.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "disabled zone is filtered out",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "disabled.com",
|
||||
Enabled: false,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.disabled.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "zone with no records is filtered out",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "empty.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{},
|
||||
},
|
||||
{
|
||||
name: "peer has access via multiple groups",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "multi.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1", "group2", "group3"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.multi.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group2": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "multi.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.multi.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple zones with mixed access",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "allowed.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.allowed.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "zone2",
|
||||
Domain: "denied.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group2"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.denied.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "allowed.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.allowed.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zone with multiple record types",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "mixed.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.mixed.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "ipv6.mixed.com",
|
||||
Type: records.RecordTypeAAAA,
|
||||
Content: "2001:db8::1",
|
||||
TTL: 600,
|
||||
},
|
||||
{
|
||||
ID: "record3",
|
||||
Name: "alias.mixed.com",
|
||||
Type: records.RecordTypeCNAME,
|
||||
Content: "www.mixed.com",
|
||||
TTL: 900,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "mixed.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.mixed.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
Name: "ipv6.mixed.com.",
|
||||
Type: int(dns.TypeAAAA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 600,
|
||||
RData: "2001:db8::1",
|
||||
},
|
||||
{
|
||||
Name: "alias.mixed.com.",
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 900,
|
||||
RData: "www.mixed.com.",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple zones both accessible",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "first.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: true,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.first.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "zone2",
|
||||
Domain: "second.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.second.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "first.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.first.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: false,
|
||||
},
|
||||
{
|
||||
Domain: "second.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.second.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 600,
|
||||
RData: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zone with multiple records of same type",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "multi-a.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.multi-a.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.multi-a.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "multi-a.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.multi-a.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
Name: "www.multi-a.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peer in multiple groups accessing different zones",
|
||||
accountZones: []*zones.Zone{
|
||||
{
|
||||
ID: "zone1",
|
||||
Domain: "zone1.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record1",
|
||||
Name: "www.zone1.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.1",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "zone2",
|
||||
Domain: "zone2.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group2"},
|
||||
Records: []*records.Record{
|
||||
{
|
||||
ID: "record2",
|
||||
Name: "www.zone2.com",
|
||||
Type: records.RecordTypeA,
|
||||
Content: "192.168.1.2",
|
||||
TTL: 300,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
peerGroups: LookupMap{"group1": struct{}{}, "group2": struct{}{}},
|
||||
expected: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "zone1.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.zone1.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
{
|
||||
Domain: "zone2.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "www.zone2.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.2",
|
||||
},
|
||||
},
|
||||
SearchDomainDisabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterPeerAppliedZones(ctx, tt.accountZones, tt.peerGroups)
|
||||
require.Equal(t, len(tt.expected), len(result), "number of custom zones should match")
|
||||
|
||||
for i, expectedZone := range tt.expected {
|
||||
assert.Equal(t, expectedZone.Domain, result[i].Domain, "domain should match")
|
||||
assert.Equal(t, expectedZone.SearchDomainDisabled, result[i].SearchDomainDisabled, "search domain disabled flag should match")
|
||||
assert.Equal(t, len(expectedZone.Records), len(result[i].Records), "number of records should match")
|
||||
|
||||
for j, expectedRecord := range expectedZone.Records {
|
||||
assert.Equal(t, expectedRecord.Name, result[i].Records[j].Name, "record name should match")
|
||||
assert.Equal(t, expectedRecord.Type, result[i].Records[j].Type, "record type should match")
|
||||
assert.Equal(t, expectedRecord.Class, result[i].Records[j].Class, "record class should match")
|
||||
assert.Equal(t, expectedRecord.TTL, result[i].Records[j].TTL, "record TTL should match")
|
||||
assert.Equal(t, expectedRecord.RData, result[i].Records[j].RData, "record RData should match")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
@@ -25,11 +26,12 @@ func (a *Account) GetPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeers map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
|
||||
@@ -70,13 +70,13 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -115,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
b.Run("old builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -124,7 +124,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
for range b.N {
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
for _, peerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -177,7 +177,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -185,7 +185,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
err = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
require.NoError(t, err, "error adding peer to cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -240,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -250,7 +250,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -317,7 +317,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -325,7 +325,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
err = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
require.NoError(t, err, "error adding router to cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -402,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -412,7 +412,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -458,7 +458,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -466,7 +466,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
err = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
require.NoError(t, err, "error deleting peer from cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -537,7 +537,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -545,7 +545,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
err = builder.OnPeerDeleted(account, deletedRouterID)
|
||||
require.NoError(t, err, "error deleting routing peer from cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
@@ -597,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
b.Run("old builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -607,7 +607,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -944,7 +944,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -1033,7 +1034,7 @@ func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account {
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
|
||||
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone,
|
||||
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
@@ -1057,7 +1058,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
|
||||
nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers)
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||
@@ -1074,8 +1075,8 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{},
|
||||
) *NetworkMap {
|
||||
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
@@ -1125,13 +1126,26 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||
}
|
||||
|
||||
finalDNSConfig := *dnsConfig
|
||||
if finalDNSConfig.ServiceEnable && customZone.Domain != "" {
|
||||
if finalDNSConfig.ServiceEnable {
|
||||
var zones []nbdns.CustomZone
|
||||
records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: customZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
|
||||
peerGroupsSlice := b.cache.peerToGroups[peer.ID]
|
||||
peerGroups := make(LookupMap, len(peerGroupsSlice))
|
||||
for _, groupID := range peerGroupsSlice {
|
||||
peerGroups[groupID] = struct{}{}
|
||||
}
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
|
||||
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
zones = append(zones, filteredAccountZones...)
|
||||
|
||||
finalDNSConfig.CustomZones = zones
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
package util
|
||||
|
||||
import "regexp"
|
||||
|
||||
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
|
||||
|
||||
// Difference returns the elements in `a` that aren't in `b`.
|
||||
func Difference(a, b []string) []string {
|
||||
mb := make(map[string]struct{}, len(b))
|
||||
@@ -50,3 +54,10 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsValidDomain(domain string) bool {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
return domainRegex.MatchString(domain)
|
||||
}
|
||||
|
||||
@@ -59,9 +59,13 @@ type Client struct {
|
||||
Routes *RoutesAPI
|
||||
|
||||
// DNS NetBird DNS APIs
|
||||
// see more: https://docs.netbird.io/api/resources/routes
|
||||
// see more: https://docs.netbird.io/api/resources/dns
|
||||
DNS *DNSAPI
|
||||
|
||||
// DNSZones NetBird DNS Zones APIs
|
||||
// see more: https://docs.netbird.io/api/resources/dns-zones
|
||||
DNSZones *DNSZonesAPI
|
||||
|
||||
// GeoLocation NetBird Geo Location APIs
|
||||
// see more: https://docs.netbird.io/api/resources/geo-locations
|
||||
GeoLocation *GeoLocationAPI
|
||||
@@ -113,6 +117,7 @@ func (c *Client) initialize() {
|
||||
c.Networks = &NetworksAPI{c}
|
||||
c.Routes = &RoutesAPI{c}
|
||||
c.DNS = &DNSAPI{c}
|
||||
c.DNSZones = &DNSZonesAPI{c}
|
||||
c.GeoLocation = &GeoLocationAPI{c}
|
||||
c.Events = &EventsAPI{c}
|
||||
}
|
||||
|
||||
170
shared/management/client/rest/dns_zones.go
Normal file
170
shared/management/client/rest/dns_zones.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// DNSZonesAPI APIs for DNS Zones Management, do not use directly
|
||||
type DNSZonesAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// ListZones list all DNS zones
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-zones
|
||||
func (a *DNSZonesAPI) ListZones(ctx context.Context) ([]api.Zone, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Zone](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// GetZone get DNS zone info
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-zone
|
||||
func (a *DNSZonesAPI) GetZone(ctx context.Context, zoneID string) (*api.Zone, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Zone](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateZone create new DNS zone
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-zone
|
||||
func (a *DNSZonesAPI) CreateZone(ctx context.Context, request api.PostApiDnsZonesJSONRequestBody) (*api.Zone, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Zone](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateZone update DNS zone info
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-zone
|
||||
func (a *DNSZonesAPI) UpdateZone(ctx context.Context, zoneID string, request api.PutApiDnsZonesZoneIdJSONRequestBody) (*api.Zone, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Zone](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteZone delete DNS zone
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-zone
|
||||
func (a *DNSZonesAPI) DeleteZone(ctx context.Context, zoneID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRecords list all DNS records in a zone
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-records
|
||||
func (a *DNSZonesAPI) ListRecords(ctx context.Context, zoneID string) ([]api.DNSRecord, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.DNSRecord](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// GetRecord get DNS record info
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-record
|
||||
func (a *DNSZonesAPI) GetRecord(ctx context.Context, zoneID, recordID string) (*api.DNSRecord, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.DNSRecord](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateRecord create new DNS record in a zone
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-record
|
||||
func (a *DNSZonesAPI) CreateRecord(ctx context.Context, zoneID string, request api.PostApiDnsZonesZoneIdRecordsJSONRequestBody) (*api.DNSRecord, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones/"+zoneID+"/records", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.DNSRecord](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateRecord update DNS record info
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-record
|
||||
func (a *DNSZonesAPI) UpdateRecord(ctx context.Context, zoneID, recordID string, request api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody) (*api.DNSRecord, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID+"/records/"+recordID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.DNSRecord](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteRecord delete DNS record
|
||||
// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-record
|
||||
func (a *DNSZonesAPI) DeleteRecord(ctx context.Context, zoneID, recordID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
460
shared/management/client/rest/dns_zones_test.go
Normal file
460
shared/management/client/rest/dns_zones_test.go
Normal file
@@ -0,0 +1,460 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testZone = api.Zone{
|
||||
Id: "zone123",
|
||||
Name: "test-zone",
|
||||
Domain: "example.com",
|
||||
Enabled: true,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
}
|
||||
|
||||
testDNSRecord = api.DNSRecord{
|
||||
Id: "record123",
|
||||
Name: "www",
|
||||
Content: "192.168.1.1",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Ttl: 300,
|
||||
}
|
||||
)
|
||||
|
||||
func TestDNSZone_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal([]api.Zone{testZone})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.ListZones(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testZone, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.ListZones(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testZone)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.GetZone(context.Background(), "zone123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testZone, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.GetZone(context.Background(), "zone123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiDnsZonesJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-zone", req.Name)
|
||||
assert.Equal(t, "example.com", req.Domain)
|
||||
retBytes, _ := json.Marshal(testZone)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
enabled := true
|
||||
ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "test-zone",
|
||||
Domain: "example.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testZone, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{
|
||||
Name: "test-zone",
|
||||
Domain: "example.com",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Invalid request", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated-zone", req.Name)
|
||||
retBytes, _ := json.Marshal(testZone)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
enabled := true
|
||||
ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{
|
||||
Name: "updated-zone",
|
||||
Domain: "example.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"group1"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testZone, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{
|
||||
Name: "updated-zone",
|
||||
Domain: "example.com",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Invalid request", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.DNSZones.DeleteZone(context.Background(), "zone123")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZone_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.DNSZones.DeleteZone(context.Background(), "zone123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal([]api.DNSRecord{testDNSRecord})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.ListRecords(context.Background(), "zone123")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testDNSRecord, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Zone not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.ListRecords(context.Background(), "zone123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Zone not found", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
retBytes, _ := json.Marshal(testDNSRecord)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testDNSRecord, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "www", req.Name)
|
||||
assert.Equal(t, "192.168.1.1", req.Content)
|
||||
assert.Equal(t, api.DNSRecordTypeA, req.Type)
|
||||
retBytes, _ := json.Marshal(testDNSRecord)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "www",
|
||||
Content: "192.168.1.1",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Ttl: 300,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testDNSRecord, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
|
||||
Name: "www",
|
||||
Content: "192.168.1.1",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Ttl: 300,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Invalid record", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "api", req.Name)
|
||||
assert.Equal(t, "192.168.1.2", req.Content)
|
||||
retBytes, _ := json.Marshal(testDNSRecord)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "api",
|
||||
Content: "192.168.1.2",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Ttl: 300,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testDNSRecord, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
|
||||
Name: "api",
|
||||
Content: "192.168.1.2",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Ttl: 300,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Invalid record", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSRecord_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSZones_Integration(t *testing.T) {
|
||||
enabled := true
|
||||
zoneReq := api.ZoneRequest{
|
||||
Name: "test-zone",
|
||||
Domain: "test.example.com",
|
||||
Enabled: &enabled,
|
||||
EnableSearchDomain: false,
|
||||
DistributionGroups: []string{"cs1tnh0hhcjnqoiuebeg"},
|
||||
}
|
||||
|
||||
recordReq := api.DNSRecordRequest{
|
||||
Name: "api.test.example.com",
|
||||
Content: "192.168.1.100",
|
||||
Type: api.DNSRecordTypeA,
|
||||
Ttl: 300,
|
||||
}
|
||||
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
zone, err := c.DNSZones.CreateZone(context.Background(), zoneReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-zone", zone.Name)
|
||||
assert.Equal(t, "test.example.com", zone.Domain)
|
||||
|
||||
zones, err := c.DNSZones.ListZones(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *zone, zones[0])
|
||||
|
||||
getZone, err := c.DNSZones.GetZone(context.Background(), zone.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *zone, *getZone)
|
||||
|
||||
zoneReq.Name = "updated-zone"
|
||||
updatedZone, err := c.DNSZones.UpdateZone(context.Background(), zone.Id, zoneReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated-zone", updatedZone.Name)
|
||||
|
||||
record, err := c.DNSZones.CreateRecord(context.Background(), zone.Id, recordReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "api.test.example.com", record.Name)
|
||||
assert.Equal(t, "192.168.1.100", record.Content)
|
||||
|
||||
records, err := c.DNSZones.ListRecords(context.Background(), zone.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *record, records[0])
|
||||
|
||||
getRecord, err := c.DNSZones.GetRecord(context.Background(), zone.Id, record.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *record, *getRecord)
|
||||
|
||||
recordReq.Name = "www.test.example.com"
|
||||
updatedRecord, err := c.DNSZones.UpdateRecord(context.Background(), zone.Id, record.Id, recordReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "www.test.example.com", updatedRecord.Name)
|
||||
|
||||
err = c.DNSZones.DeleteRecord(context.Background(), zone.Id, record.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
records, err = c.DNSZones.ListRecords(context.Background(), zone.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 0)
|
||||
|
||||
err = c.DNSZones.DeleteZone(context.Background(), zone.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
zones, err = c.DNSZones.ListZones(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, zones, 0)
|
||||
})
|
||||
}
|
||||
@@ -25,6 +25,8 @@ tags:
|
||||
description: Interact with and view information about routes.
|
||||
- name: DNS
|
||||
description: Interact with and view information about DNS configuration.
|
||||
- name: DNS Zones
|
||||
description: Interact with and view information about custom DNS zones.
|
||||
- name: Events
|
||||
description: View information about the account and network events.
|
||||
- name: Accounts
|
||||
@@ -1779,6 +1781,100 @@ components:
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
required:
|
||||
- disabled_management_groups
|
||||
ZoneRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Zone name identifier
|
||||
type: string
|
||||
maxLength: 255
|
||||
minLength: 1
|
||||
example: Office Zone
|
||||
domain:
|
||||
description: Zone domain (FQDN)
|
||||
type: string
|
||||
example: example.com
|
||||
enabled:
|
||||
description: Zone status
|
||||
type: boolean
|
||||
default: true
|
||||
enable_search_domain:
|
||||
description: Enable this zone as a search domain
|
||||
type: boolean
|
||||
example: false
|
||||
distribution_groups:
|
||||
description: Group IDs that defines groups of peers that will resolve this zone
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
required:
|
||||
- name
|
||||
- domain
|
||||
- enable_search_domain
|
||||
- distribution_groups
|
||||
Zone:
|
||||
allOf:
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: Zone ID
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
records:
|
||||
description: DNS records associated with this zone
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/DNSRecord'
|
||||
required:
|
||||
- id
|
||||
- enabled
|
||||
- records
|
||||
- $ref: '#/components/schemas/ZoneRequest'
|
||||
DNSRecordType:
|
||||
type: string
|
||||
description: DNS record type
|
||||
enum:
|
||||
- A
|
||||
- AAAA
|
||||
- CNAME
|
||||
example: A
|
||||
DNSRecordRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: FQDN for the DNS record. Must be a subdomain within or match the zone's domain.
|
||||
type: string
|
||||
example: www.example.com
|
||||
type:
|
||||
$ref: '#/components/schemas/DNSRecordType'
|
||||
content:
|
||||
description: DNS record content (IP address for A/AAAA, domain for CNAME)
|
||||
type: string
|
||||
maxLength: 255
|
||||
minLength: 1
|
||||
example: 192.168.1.1
|
||||
ttl:
|
||||
description: Time to live in seconds
|
||||
type: integer
|
||||
minimum: 0
|
||||
example: 300
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- content
|
||||
- ttl
|
||||
DNSRecord:
|
||||
allOf:
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: DNS record ID
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
required:
|
||||
- id
|
||||
- $ref: '#/components/schemas/DNSRecordRequest'
|
||||
Event:
|
||||
type: object
|
||||
properties:
|
||||
@@ -4733,6 +4829,347 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/dns/zones:
|
||||
get:
|
||||
summary: List all DNS Zones
|
||||
description: Returns a list of all custom DNS zones
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of DNS Zones
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Zone'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a DNS Zone
|
||||
description: Creates a new custom DNS zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
requestBody:
|
||||
description: A DNS zone object
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/ZoneRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Object of the created DNS Zone
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Zone'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/dns/zones/{zoneId}:
|
||||
get:
|
||||
summary: Retrieve a DNS Zone
|
||||
description: Returns information about a specific DNS zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Object of a DNS Zone
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Zone'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
put:
|
||||
summary: Update a DNS Zone
|
||||
description: Updates a custom DNS zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
requestBody:
|
||||
description: A DNS zone object
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/ZoneRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Object of the updated DNS Zone
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Zone'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a DNS Zone
|
||||
description: Deletes a custom DNS zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
responses:
|
||||
'200':
|
||||
description: Zone deletion successful
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/dns/zones/{zoneId}/records:
|
||||
get:
|
||||
summary: List all DNS Records
|
||||
description: Returns a list of all DNS records in a zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of DNS Records
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/DNSRecord'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a DNS Record
|
||||
description: Creates a new DNS record in a zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
requestBody:
|
||||
description: A DNS record object
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/DNSRecordRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Object of the created DNS Record
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DNSRecord'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/dns/zones/{zoneId}/records/{recordId}:
|
||||
get:
|
||||
summary: Retrieve a DNS Record
|
||||
description: Returns information about a specific DNS record
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
- in: path
|
||||
name: recordId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a DNS record
|
||||
example: chacbco6lnnbn6cg5s92
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Object of a DNS Record
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DNSRecord'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
put:
|
||||
summary: Update a DNS Record
|
||||
description: Updates a DNS record in a zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
- in: path
|
||||
name: recordId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a DNS record
|
||||
example: chacbco6lnnbn6cg5s92
|
||||
requestBody:
|
||||
description: A DNS record object
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/DNSRecordRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Object of the updated DNS Record
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DNSRecord'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a DNS Record
|
||||
description: Deletes a DNS record from a zone
|
||||
tags: [ DNS Zones ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: zoneId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a zone
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
- in: path
|
||||
name: recordId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a DNS record
|
||||
example: chacbco6lnnbn6cg5s92
|
||||
responses:
|
||||
'200':
|
||||
description: Record deletion successful
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/events/audit:
|
||||
get:
|
||||
summary: List all Audit Events
|
||||
|
||||
@@ -12,6 +12,13 @@ const (
|
||||
TokenAuthScopes = "TokenAuth.Scopes"
|
||||
)
|
||||
|
||||
// Defines values for DNSRecordType.
|
||||
const (
|
||||
DNSRecordTypeA DNSRecordType = "A"
|
||||
DNSRecordTypeAAAA DNSRecordType = "AAAA"
|
||||
DNSRecordTypeCNAME DNSRecordType = "CNAME"
|
||||
)
|
||||
|
||||
// Defines values for EventActivityCode.
|
||||
const (
|
||||
EventActivityCodeAccountCreate EventActivityCode = "account.create"
|
||||
@@ -427,6 +434,42 @@ type CreateSetupKeyRequest struct {
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
}
|
||||
|
||||
// DNSRecord defines model for DNSRecord.
|
||||
type DNSRecord struct {
|
||||
// Content DNS record content (IP address for A/AAAA, domain for CNAME)
|
||||
Content string `json:"content"`
|
||||
|
||||
// Id DNS record ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Name FQDN for the DNS record. Must be a subdomain within or match the zone's domain.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Ttl Time to live in seconds
|
||||
Ttl int `json:"ttl"`
|
||||
|
||||
// Type DNS record type
|
||||
Type DNSRecordType `json:"type"`
|
||||
}
|
||||
|
||||
// DNSRecordRequest defines model for DNSRecordRequest.
|
||||
type DNSRecordRequest struct {
|
||||
// Content DNS record content (IP address for A/AAAA, domain for CNAME)
|
||||
Content string `json:"content"`
|
||||
|
||||
// Name FQDN for the DNS record. Must be a subdomain within or match the zone's domain.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Ttl Time to live in seconds
|
||||
Ttl int `json:"ttl"`
|
||||
|
||||
// Type DNS record type
|
||||
Type DNSRecordType `json:"type"`
|
||||
}
|
||||
|
||||
// DNSRecordType DNS record type
|
||||
type DNSRecordType string
|
||||
|
||||
// DNSSettings defines model for DNSSettings.
|
||||
type DNSSettings struct {
|
||||
// DisabledManagementGroups Groups whose DNS management is disabled
|
||||
@@ -1999,6 +2042,48 @@ type UserRequest struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// Zone defines model for Zone.
|
||||
type Zone struct {
|
||||
// DistributionGroups Group IDs that defines groups of peers that will resolve this zone
|
||||
DistributionGroups []string `json:"distribution_groups"`
|
||||
|
||||
// Domain Zone domain (FQDN)
|
||||
Domain string `json:"domain"`
|
||||
|
||||
// EnableSearchDomain Enable this zone as a search domain
|
||||
EnableSearchDomain bool `json:"enable_search_domain"`
|
||||
|
||||
// Enabled Zone status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Id Zone ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Name Zone name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Records DNS records associated with this zone
|
||||
Records []DNSRecord `json:"records"`
|
||||
}
|
||||
|
||||
// ZoneRequest defines model for ZoneRequest.
|
||||
type ZoneRequest struct {
|
||||
// DistributionGroups Group IDs that defines groups of peers that will resolve this zone
|
||||
DistributionGroups []string `json:"distribution_groups"`
|
||||
|
||||
// Domain Zone domain (FQDN)
|
||||
Domain string `json:"domain"`
|
||||
|
||||
// EnableSearchDomain Enable this zone as a search domain
|
||||
EnableSearchDomain bool `json:"enable_search_domain"`
|
||||
|
||||
// Enabled Zone status
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
|
||||
// Name Zone name identifier
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic.
|
||||
type GetApiEventsNetworkTrafficParams struct {
|
||||
// Page Page number
|
||||
@@ -2083,6 +2168,18 @@ type PutApiDnsNameserversNsgroupIdJSONRequestBody = NameserverGroupRequest
|
||||
// PutApiDnsSettingsJSONRequestBody defines body for PutApiDnsSettings for application/json ContentType.
|
||||
type PutApiDnsSettingsJSONRequestBody = DNSSettings
|
||||
|
||||
// PostApiDnsZonesJSONRequestBody defines body for PostApiDnsZones for application/json ContentType.
|
||||
type PostApiDnsZonesJSONRequestBody = ZoneRequest
|
||||
|
||||
// PutApiDnsZonesZoneIdJSONRequestBody defines body for PutApiDnsZonesZoneId for application/json ContentType.
|
||||
type PutApiDnsZonesZoneIdJSONRequestBody = ZoneRequest
|
||||
|
||||
// PostApiDnsZonesZoneIdRecordsJSONRequestBody defines body for PostApiDnsZonesZoneIdRecords for application/json ContentType.
|
||||
type PostApiDnsZonesZoneIdRecordsJSONRequestBody = DNSRecordRequest
|
||||
|
||||
// PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody defines body for PutApiDnsZonesZoneIdRecordsRecordId for application/json ContentType.
|
||||
type PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody = DNSRecordRequest
|
||||
|
||||
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
|
||||
type PostApiGroupsJSONRequestBody = GroupRequest
|
||||
|
||||
|
||||
@@ -252,3 +252,13 @@ func NewOperationNotFoundError(operation operations.Operation) error {
|
||||
func NewRouteNotFoundError(routeID string) error {
|
||||
return Errorf(NotFound, "route: %s not found", routeID)
|
||||
}
|
||||
|
||||
// NewZoneNotFoundError creates a new Error with NotFound type for a missing dns zone.
|
||||
func NewZoneNotFoundError(zoneID string) error {
|
||||
return Errorf(NotFound, "zone: %s not found", zoneID)
|
||||
}
|
||||
|
||||
// NewDNSRecordNotFoundError creates a new Error with NotFound type for a missing dns record.
|
||||
func NewDNSRecordNotFoundError(recordID string) error {
|
||||
return Errorf(NotFound, "dns record: %s not found", recordID)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user