Compare commits

...

3 Commits

Author SHA1 Message Date
Pascal Fischer
0c6c5fdc70 extend example 2024-03-18 15:31:47 +01:00
Pascal Fischer
27c3a4c5d6 simplify storage inheritance 2024-03-14 11:42:25 +01:00
Pascal Fischer
f31b06fc92 add example setup for management refactor 2024-03-13 23:07:00 +01:00
64 changed files with 5825 additions and 0 deletions

View File

@@ -0,0 +1,641 @@
package server
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// GRPCServer an instance of a Management gRPC API server
type GRPCServer struct {
accountManager AccountManager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
config *Config
turnCredentialsManager TURNCredentialsManager
jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
var jwtValidator *jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}
} else {
log.Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
})
if err != nil {
return nil, err
}
}
var audience, userIDClaim string
if config.HttpConfig != nil {
audience = config.HttpConfig.AuthAudience
userIDClaim = config.HttpConfig.AuthUserIDClaim
}
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
accountManager: accountManager,
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
}, nil
}
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
}
now := time.Now().Add(24 * time.Hour)
secs := int64(now.Second())
nanos := int32(now.Nanosecond())
expiresAt := &timestamp.Timestamp{Seconds: secs, Nanos: nanos}
return &proto.ServerKeyResponse{
Key: s.wgKey.PublicKey().String(),
ExpiresAt: expiresAt,
}, nil
}
func getRealIP(ctx context.Context) net.IP {
if addr, ok := realip.FromContext(ctx); ok {
return net.IP(addr.AsSlice())
}
return nil
}
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
reqStart := time.Now()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
realIP := getRealIP(srv.Context())
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(req, syncReq)
if err != nil {
return err
}
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
if err != nil {
return mapError(err)
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
}
updates := s.peersUpdateManager.CreateChannel(peer.ID)
s.ephemeralManager.OnPeerConnected(peer)
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
}
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
// keep a connection to the peer and send updates when available
for {
select {
// condition when there are some updates
case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
}
if !open {
log.Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(peer)
return nil
}
log.Debugf("received an update for peer %s", peerKey.String())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.Debugf("sent an update to peer %s", peerKey.String())
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(peer)
return srv.Context().Err()
}
}
}
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil)
s.ephemeralManager.OnPeerDisconnected(peer)
}
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
}
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error())
}
return claims.UserId, nil
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(err error) error {
if e, ok := internalStatus.FromError(err); ok {
switch e.Type() {
case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, e.Message)
case internalStatus.Unauthorized:
return status.Errorf(codes.PermissionDenied, e.Message)
case internalStatus.Unauthenticated:
return status.Errorf(codes.PermissionDenied, e.Message)
case internalStatus.PreconditionFailed:
return status.Errorf(codes.FailedPrecondition, e.Message)
case internalStatus.NotFound:
return status.Errorf(codes.NotFound, e.Message)
default:
}
}
log.Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
osVersion := loginReq.GetMeta().GetOSVersion()
if osVersion == "" {
osVersion = loginReq.GetMeta().GetCore()
}
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(loginReq.GetMeta().GetNetworkAddresses()))
for _, addr := range loginReq.GetMeta().GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil {
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
continue
}
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
NetIP: netAddr,
Mac: addr.GetMac(),
})
}
return nbpeer.PeerSystemMeta{
Hostname: loginReq.GetMeta().GetHostname(),
GoOS: loginReq.GetMeta().GetGoOS(),
Kernel: loginReq.GetMeta().GetKernel(),
Platform: loginReq.GetMeta().GetPlatform(),
OS: loginReq.GetMeta().GetOS(),
OSVersion: osVersion,
WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(),
UIVersion: loginReq.GetMeta().GetUiVersion(),
KernelVersion: loginReq.GetMeta().GetKernelVersion(),
NetworkAddresses: networkAddresses,
SystemSerialNumber: loginReq.GetMeta().GetSysSerialNumber(),
SystemProductName: loginReq.GetMeta().GetSysProductName(),
SystemManufacturer: loginReq.GetMeta().GetSysManufacturer(),
Environment: nbpeer.Environment{
Cloud: loginReq.GetMeta().GetEnvironment().GetCloud(),
Platform: loginReq.GetMeta().GetEnvironment().GetPlatform(),
},
}
}
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
}
return peerKey, nil
}
// Login endpoint first checks whether peer is registered under any account
// In case it is, the login is successful
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now()
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
}
}()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
realIP := getRealIP(ctx)
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(req, loginReq)
if err != nil {
return nil, err
}
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
return nil, msg
}
userID := ""
// JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered,
// or it uses a setup key to register.
if loginReq.GetJwtToken() != "" {
userID, err = s.validateToken(loginReq.GetJwtToken())
if err != nil {
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
return nil, err
}
}
var sshKey []byte
if loginReq.GetPeerKeys() != nil {
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
})
if err != nil {
log.Warnf("failed logging in peer %s", peerKey)
return nil, mapError(err)
}
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(peer)
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
switch configProto {
case UDP:
return proto.HostConfig_UDP
case DTLS:
return proto.HostConfig_DTLS
case HTTP:
return proto.HostConfig_HTTP
case HTTPS:
return proto.HostConfig_HTTPS
case TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
Uri: stun.URI,
Protocol: ToResponseProto(stun.Proto),
})
}
var turns []*proto.ProtectedHostConfig
for _, turn := range config.TURNConfig.Turns {
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Username
password = turnCredentials.Password
} else {
username = turn.Username
password = turn.Password
}
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: turn.URI,
Protocol: ToResponseProto(turn.Proto),
},
User: username,
Password: password,
})
}
return &proto.WiretrusteeConfig{
Stuns: stuns,
Turns: turns,
Signal: &proto.HostConfig{
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
},
}
}
func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
Fqdn: fqdn,
}
}
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers {
fqdn := rPeer.FQDN(dnsName)
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
})
}
return remotePeers
}
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig,
RemotePeers: remotePeers,
OfflinePeers: offlinePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
}
}
// IsHealthy indicates whether the service is healthy
func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
return &proto.Empty{}, nil
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
creds := s.turnCredentialsManager.GenerateCredentials()
turnCredentials = &creds
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
return status.Errorf(codes.Internal, "error handling request")
}
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
log.Errorf("failed sending SyncResponse %v", err)
return status.Errorf(codes.Internal, "error handling request")
}
return nil
}
// GetDeviceAuthorizationFlow returns a device authorization flow information
// This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
}
flowInfoResp := &proto.DeviceAuthorizationFlow{
Provider: proto.DeviceAuthorizationFlowProvider(provider),
ProviderConfig: &proto.ProviderConfig{
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
},
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.PKCEAuthorizationFlow == nil {
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
}
flowInfoResp := &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
},
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}

View File

@@ -0,0 +1,122 @@
package http
import (
"context"
"fmt"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/cors"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/refactor/resources/peers"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const apiPrefix = "/api"
// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Audience string
UserIDClaim string
KeysLocation string
}
type DefaultAPIHandler struct {
Router *mux.Router
AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation
AuthCfg AuthCfg
}
// EmptyObject is an empty struct used to return empty JSON object
type EmptyObject struct {
}
// NewDefaultAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewDefaultAPIHandler(ctx context.Context, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups,
claimsExtractor,
authCfg.Audience,
authCfg.UserIDClaim,
)
corsMiddleware := cors.AllowAll()
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,
accountManager.GetUser)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := DefaultAPIHandler{
Router: router,
AccountManager: accountManager,
geolocationManager: LocationManager,
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
peers.RegisterPeersEndpoints(api.Router)
// api.addAccountsEndpoint()
// api.addPeersEndpoint()
// api.addUsersEndpoint()
// api.addUsersTokensEndpoint()
// api.addSetupKeysEndpoint()
// api.addRulesEndpoint()
// api.addPoliciesEndpoint()
// api.addGroupsEndpoint()
// api.addRoutesEndpoint()
// api.addDNSNameserversEndpoint()
// api.addDNSSettingEndpoint()
// api.addEventsEndpoint()
// api.addPostureCheckEndpoint()
// api.addLocationsEndpoint()
err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil { // we may have wildcard routes from integrations without methods, skip them for now
methods = []string{}
}
for _, method := range methods {
template, err := route.GetPathTemplate()
if err != nil {
return err
}
err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return rootRouter, nil
}

View File

@@ -0,0 +1,7 @@
package: api
generate:
models: true
embedded-spec: false
output: types.gen.go
compatibility:
always-prefix-enum-values: true

View File

@@ -0,0 +1,16 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@4a1477f6a8ba6ca8115cc23bb2fb67f0b9fca18e
oapi-codegen --config cfg.yaml openapi.yml
cd "$old_pwd"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
package specs

View File

@@ -0,0 +1,178 @@
package mesh
import (
"github.com/netbirdio/management-integrations/integrations"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/refactor/api/http"
"github.com/netbirdio/netbird/management/refactor/resources/network"
networkTypes "github.com/netbirdio/netbird/management/refactor/resources/network/types"
"github.com/netbirdio/netbird/management/refactor/resources/peers"
peerTypes "github.com/netbirdio/netbird/management/refactor/resources/peers/types"
"github.com/netbirdio/netbird/management/refactor/resources/policies"
"github.com/netbirdio/netbird/management/refactor/resources/routes"
"github.com/netbirdio/netbird/management/refactor/resources/settings"
"github.com/netbirdio/netbird/management/refactor/resources/users"
"github.com/netbirdio/netbird/management/refactor/store"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
)
type Controller interface {
LoginPeer()
SyncPeer()
}
type DefaultController struct {
store store.Store
peersManager peers.Manager
userManager users.Manager
policiesManager policies.Manager
settingsManager settings.Manager
networkManager network.Manager
routesManager routes.Manager
}
func NewDefaultController() *DefaultController {
storeStore, _ := store.NewDefaultStore(store.SqliteStoreEngine, "", nil)
settingsManager := settings.NewDefaultManager(storeStore)
networkManager := network.NewDefaultManager()
peersManager := peers.NewDefaultManager(storeStore, settingsManager)
routesManager := routes.NewDefaultManager(storeStore, peersManager)
usersManager := users.NewDefaultManager(storeStore, peersManager)
policiesManager := policies.NewDefaultManager(storeStore, peersManager)
apiHandler, _ := http.NewDefaultAPIHandler()
peersManager, settingsManager, usersManager, policiesManager, storeStore, apiHandler = integrations.InjectCloud(peersManager, policiesManager, settingsManager, usersManager, storeStore)
return &DefaultController{
store: storeStore,
peersManager: peersManager,
userManager: usersManager,
policiesManager: policiesManager,
settingsManager: settingsManager,
networkManager: networkManager,
routesManager: routesManager,
}
}
func (c *DefaultController) LoginPeer(login peerTypes.PeerLogin) (*peerTypes.Peer, *networkTypes.NetworkMap, error) {
peer, err := c.peersManager.GetPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
if peer.AddedWithSSOLogin() {
user, err := c.userManager.GetUser(peer.GetUserID())
if err != nil {
return nil, nil, err
}
if user.IsBlocked() {
return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked")
}
}
settings, err := c.settingsManager.GetSettings(peer.GetAccountID())
if err != nil {
return nil, nil, err
}
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
updateRemotePeers := false
if peerLoginExpired(peer, settings) {
err = checkAuth(login.UserID, peer)
if err != nil {
return nil, nil, err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
peer.UpdateLastLogin()
updateRemotePeers = true
shouldStorePeer = true
pm.eventsManager.StoreEvent(login.UserID, peer.GetID(), peer.GetAccountID(), activity.UserLoggedInPeer, peer.EventMeta(pm.accountManager.GetDNSDomain()))
}
if peer.UpdateMetaIfNew(login.Meta) {
shouldStorePeer = true
}
if peer.CheckAndUpdatePeerSSHKey(login.SSHKey) {
shouldStorePeer = true
}
if shouldStorePeer {
err := pm.repository.updatePeer(peer)
if err != nil {
return nil, nil, err
}
}
if updateRemotePeers {
am.updateAccountPeers(account)
}
return peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil
}
func (c *DefaultController) SyncPeer() {
}
func (c *DefaultController) GetPeerNetworkMap(accountID, peerID, dnsDomain string) (*networkTypes.NetworkMap, error) {
unlock := c.store.AcquireAccountLock(accountID)
defer unlock()
network, err := c.networkManager.GetNetwork(accountID)
if err != nil {
return nil, err
}
peer, err := c.peersManager.GetNetworkPeerByID(peerID)
if err != nil {
return &networkTypes.NetworkMap{
Network: network.Copy(),
}, nil
}
aclPeers, firewallRules := c.policiesManager.GetAccessiblePeersAndFirewallRules(peerID)
// exclude expired peers
var peersToConnect []*peerTypes.Peer
var expiredPeers []*peerTypes.Peer
accSettings, _ := c.settingsManager.GetSettings(peer.GetAccountID())
for _, p := range aclPeers {
expired, _ := p.LoginExpired(accSettings.GetPeerLoginExpiration())
if accSettings.GetPeerLoginExpirationEnabled() && expired {
expiredPeers = append(expiredPeers, &p)
continue
}
peersToConnect = append(peersToConnect, &p)
}
routesUpdate := c.routesManager.GetRoutesToSync(peerID, peersToConnect, accountID)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(a, dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
return &networkTypes.NetworkMap{
Peers: peersToConnect,
Network: network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
}, nil
}

View File

@@ -0,0 +1 @@
package dns

View File

@@ -0,0 +1 @@
package dns

View File

@@ -0,0 +1 @@
package dns

View File

@@ -0,0 +1,11 @@
package types
// Config represents a dns configuration that is exchanged between management and peers
type Config struct {
// ServiceEnable indicates if the service should be enabled
ServiceEnable bool
// NameServerGroups contains a list of nameserver group
NameServerGroups []*NameServerGroup
// CustomZones contains a list of custom zone
CustomZones []CustomZone
}

View File

@@ -0,0 +1,9 @@
package types
// CustomZone represents a custom zone to be resolved by the dns server
type CustomZone struct {
// Domain is the zone's domain
Domain string
// Records custom zone records
Records []SimpleRecord
}

View File

@@ -0,0 +1,75 @@
package types
import "net/netip"
// NameServerType nameserver type
type NameServerType int
// NameServer represents a DNS nameserver
type NameServer struct {
// IP address of nameserver
IP netip.Addr
// NSType nameserver type
NSType NameServerType
// Port nameserver listening port
Port int
}
// Copy copies a nameserver object
func (n *NameServer) Copy() *NameServer {
return &NameServer{
IP: n.IP,
NSType: n.NSType,
Port: n.Port,
}
}
// IsEqual compares one nameserver with the other
func (n *NameServer) IsEqual(other *NameServer) bool {
return other.IP == n.IP &&
other.NSType == n.NSType &&
other.Port == n.Port
}
func compareNameServerList(list, other []NameServer) bool {
if len(list) != len(other) {
return false
}
for _, ns := range list {
if !containsNameServer(ns, other) {
return false
}
}
return true
}
func containsNameServer(element NameServer, list []NameServer) bool {
for _, ns := range list {
if ns.IsEqual(&element) {
return true
}
}
return false
}
func compareGroupsList(list, other []string) bool {
if len(list) != len(other) {
return false
}
for _, id := range list {
match := false
for _, otherID := range other {
if id == otherID {
match = true
break
}
}
if !match {
return false
}
}
return true
}

View File

@@ -0,0 +1,65 @@
package types
type NameServerGroup interface {
}
type DefaultNameServerGroup struct {
// ID identifier of group
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
// Name group name
Name string
// Description group description
Description string
// NameServers list of nameservers
NameServers []NameServer `gorm:"serializer:json"`
// Groups list of peer group IDs to distribute the nameservers information
Groups []string `gorm:"serializer:json"`
// Primary indicates that the nameserver group is the primary resolver for any dns query
Primary bool
// Domains indicate the dns query domains to use with this nameserver group
Domains []string `gorm:"serializer:json"`
// Enabled group status
Enabled bool
// SearchDomainsEnabled indicates whether to add match domains to search domains list or not
SearchDomainsEnabled bool
}
// EventMeta returns activity event meta related to the nameserver group
func (g *DefaultNameServerGroup) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
}
// Copy copies a nameserver group object
func (g *DefaultNameServerGroup) Copy() *DefaultNameServerGroup {
nsGroup := &DefaultNameServerGroup{
ID: g.ID,
Name: g.Name,
Description: g.Description,
NameServers: make([]NameServer, len(g.NameServers)),
Groups: make([]string, len(g.Groups)),
Enabled: g.Enabled,
Primary: g.Primary,
Domains: make([]string, len(g.Domains)),
SearchDomainsEnabled: g.SearchDomainsEnabled,
}
copy(nsGroup.NameServers, g.NameServers)
copy(nsGroup.Groups, g.Groups)
copy(nsGroup.Domains, g.Domains)
return nsGroup
}
// IsEqual compares one nameserver group with the other
func (g *DefaultNameServerGroup) IsEqual(other *DefaultNameServerGroup) bool {
return other.ID == g.ID &&
other.Name == g.Name &&
other.Description == g.Description &&
other.Primary == g.Primary &&
other.SearchDomainsEnabled == g.SearchDomainsEnabled &&
compareNameServerList(g.NameServers, other.NameServers) &&
compareGroupsList(g.Groups, other.Groups) &&
compareGroupsList(g.Domains, other.Domains)
}

View File

@@ -0,0 +1,7 @@
package types
type Settings interface {
}
type DefaultSettings struct {
}

View File

@@ -0,0 +1,53 @@
package types
import (
"fmt"
"net"
"github.com/miekg/dns"
)
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
type SimpleRecord struct {
// Name domain name
Name string
// Type of record, 1 for A, 5 for CNAME, 28 for AAAA. see https://pkg.go.dev/github.com/miekg/dns@v1.1.41#pkg-constants
Type int
// Class dns class, currently use the DefaultClass for all records
Class string
// TTL time-to-live for the record
TTL int
// RData is the actual value resolved in a dns query
RData string
}
// String returns a string of the simple record formatted as:
// <Name> <TTL> <Class> <Type> <RDATA>
func (s SimpleRecord) String() string {
fqdn := dns.Fqdn(s.Name)
return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData)
}
// Len returns the length of the RData field, based on its type
func (s SimpleRecord) Len() uint16 {
emptyString := s.RData == ""
switch s.Type {
case 1:
if emptyString {
return 0
}
return net.IPv4len
case 5:
if emptyString || s.RData == "." {
return 1
}
return uint16(len(s.RData) + 1)
case 28:
if emptyString {
return 0
}
return net.IPv6len
default:
return 0
}
}

View File

@@ -0,0 +1 @@
package groups

View File

@@ -0,0 +1 @@
package groups

View File

@@ -0,0 +1 @@
package groups

View File

@@ -0,0 +1,23 @@
package types
type Group interface {
}
type DefaultGroup struct {
// ID of the group
ID string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name visible in the UI
Name string
// Issued of the group
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}

View File

@@ -0,0 +1 @@
package network

View File

@@ -0,0 +1,19 @@
package network
import "github.com/netbirdio/netbird/management/refactor/resources/network/types"
type Manager interface {
GetNetwork(accountID string) (types.Network, error)
}
type DefaultManager struct {
}
func NewDefaultManager() *DefaultManager {
return &DefaultManager{}
}
func (d DefaultManager) GetNetwork(accountID string) (types.Network, error) {
// TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1 @@
package network

View File

@@ -0,0 +1,70 @@
package types
import (
"math/rand"
"net"
"sync"
"time"
"github.com/c-robinson/iplib"
"github.com/rs/xid"
)
const (
// SubnetSize is a size of the subnet of the global network, e.g. 100.77.0.0/16
SubnetSize = 16
// NetSize is a global network size 100.64.0.0/10
NetSize = 10
)
type Network struct {
Identifier string `json:"id"`
Net net.IPNet `gorm:"serializer:gob"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64
mu sync.Mutex `json:"-" gorm:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
func NewNetwork() *Network {
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
sub, _ := n.Subnet(SubnetSize)
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)
intn := r.Intn(len(sub))
return &Network{
Identifier: xid.New().String(),
Net: sub[intn].IPNet,
Dns: "",
Serial: 0}
}
// IncSerial increments Serial by 1 reflecting that the network state has been changed
func (n *Network) IncSerial() {
n.mu.Lock()
defer n.mu.Unlock()
n.Serial++
}
// CurrentSerial returns the Network.Serial of the network (latest state id)
func (n *Network) CurrentSerial() uint64 {
n.mu.Lock()
defer n.mu.Unlock()
return n.Serial
}
func (n *Network) Copy() *Network {
return &Network{
Identifier: n.Identifier,
Net: n.Net,
Dns: n.Dns,
Serial: n.Serial,
}
}

View File

@@ -0,0 +1,18 @@
package types
import (
peerTypes "github.com/netbirdio/netbird/management/refactor/resources/peers/types"
policyTypes "github.com/netbirdio/netbird/management/refactor/resources/policies/types"
routeTypes "github.com/netbirdio/netbird/management/refactor/resources/routes/types"
nbdns "github.com/netbirdio/netbird/dns"
)
type NetworkMap struct {
Peers []*peerTypes.Peer
Network *Network
Routes []*routeTypes.Route
DNSConfig nbdns.Config
OfflinePeers []*peerTypes.Peer
FirewallRules []*policyTypes.FirewallRule
}

View File

@@ -0,0 +1,317 @@
package peers
import (
"encoding/json"
"fmt"
"net/http"
"github.com/gorilla/mux"
http2 "github.com/netbirdio/netbird/management/refactor/api/http"
peerTypes "github.com/netbirdio/netbird/management/refactor/resources/peers/types"
"github.com/netbirdio/netbird/management/refactor/store"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
func RegisterPeersEndpoints(manager Manager, router *mux.Router) {
peersHandler := NewDefaultPeersHandler(manager, apiHandler.AuthCfg)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
}
// DefaultPeersHandler is a handler that returns peers of the account
type DefaultPeersHandler struct {
peersManager Manager
store store.Store
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewDefaultPeersHandler creates a new PeersHandler HTTP handler
func NewDefaultPeersHandler(manager Manager, authCfg AuthCfg) *DefaultPeersHandler {
return &DefaultPeersHandler{
peersManager: manager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
func (h *DefaultPeersHandler) checkPeerStatus(peer *peerTypes.Peer) (*peerTypes.Peer, error) {
peerToReturn := peer.Copy()
if peer.Status.Connected {
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected
if !h.accountManager.HasConnectedChannel(peer.ID) {
peerToReturn.Status.Connected = false
}
}
return peerToReturn, nil
}
func (h *DefaultPeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.peersManager.GetPeerByID(account.Id, peerID, userID)
if err != nil {
util.WriteError(err, w)
return
}
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
groupsInfo := toGroupsInfo(account.Groups, peer.ID)
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain())
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers))
}
func (h *DefaultPeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
update := &nbpeer.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled}
if req.ApprovalRequired != nil {
update.Status = &nbpeer.PeerStatus{RequiresApproval: *req.ApprovalRequired}
}
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update)
if err != nil {
util.WriteError(err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain())
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers))
}
func (h *DefaultPeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(accountID, peerID, userID)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, http2.EmptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *DefaultPeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
switch r.Method {
case http.MethodDelete:
h.deletePeer(account.Id, user.Id, peerID, w)
return
case http.MethodPut:
h.updatePeer(account, user, peerID, w, r)
return
case http.MethodGet:
h.getPeer(account, peerID, user.Id, w)
return
default:
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
}
}
// GetAllPeers returns a list of all peers associated with a provided account
func (h *DefaultPeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers := h.accessiblePeersNumber(account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
}
util.WriteJSONObject(w, respBody)
return
default:
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
}
}
func (h *DefaultPeersHandler) accessiblePeersNumber(account *server.Account, peerID string) int {
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain())
return len(netMap.Peers) + len(netMap.OfflinePeers)
}
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
}
for _, p := range netMap.OfflinePeers {
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
}
return accessiblePeers
}
func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
for _, group := range groups {
_, ok := groupsChecked[group.ID]
if ok {
continue
}
groupsChecked[group.ID] = struct{}{}
for _, pk := range group.Peers {
if pk == peerID {
info := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
groupsInfo = append(groupsInfo, info)
break
}
}
}
return groupsInfo
}
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
}
return &api.Peer{
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer,
ApprovalRequired: &peer.Status.RequiresApproval,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
}
}
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
}
return &api.PeerBatch{
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
AccessiblePeersCount: accessiblePeersCount,
ApprovalRequired: &peer.Status.RequiresApproval,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
}
}
func fqdn(peer *nbpeer.Peer, dnsDomain string) string {
fqdn := peer.FQDN(dnsDomain)
if fqdn == "" {
return peer.DNSLabel
} else {
return fqdn
}
}

View File

@@ -0,0 +1,51 @@
package peers
import (
"github.com/netbirdio/netbird/management/refactor/resources/peers/types"
"github.com/netbirdio/netbird/management/refactor/resources/settings"
)
type Manager interface {
GetPeerByPubKey(pubKey string) (types.Peer, error)
GetPeerByID(id string) (types.Peer, error)
GetNetworkPeerByID(id string) (types.Peer, error)
GetNetworkPeersInAccount(id string) ([]types.Peer, error)
}
type DefaultManager struct {
repository Repository
settingsManager settings.Manager
}
func NewDefaultManager(repository Repository, settingsManager settings.Manager) *DefaultManager {
return &DefaultManager{
repository: repository,
settingsManager: settingsManager,
}
}
func (dm *DefaultManager) GetNetworkPeerByID(id string) (types.Peer, error) {
return dm.repository.FindPeerByID(id)
}
func (dm *DefaultManager) GetNetworkPeersInAccount(accountId string) ([]types.Peer, error) {
defaultPeers, err := dm.repository.FindAllPeersInAccount(accountId)
if err != nil {
return nil, err
}
peers := make([]types.Peer, len(defaultPeers))
for _, dp := range defaultPeers {
peers = append(peers, dp)
}
return peers, nil
}
func (dm *DefaultManager) GetPeerByPubKey(pubKey string) (types.Peer, error) {
return dm.repository.FindPeerByPubKey(pubKey)
}
func (dm *DefaultManager) GetPeerByID(id string) (types.Peer, error) {
return dm.repository.FindPeerByID(id)
}

View File

@@ -0,0 +1,10 @@
package peers
import "github.com/netbirdio/netbird/management/refactor/resources/peers/types"
type Repository interface {
FindPeerByPubKey(pubKey string) (types.Peer, error)
FindPeerByID(id string) (types.Peer, error)
FindAllPeersInAccount(id string) ([]types.Peer, error)
UpdatePeer(peer types.Peer) error
}

View File

@@ -0,0 +1,261 @@
package types
import (
"fmt"
"net"
"time"
)
type Peer interface {
GetID() string
SetID(string)
GetAccountID() string
SetAccountID(string)
GetKey() string
SetKey(string)
GetSetupKey() string
SetSetupKey(string)
GetIP() net.IP
SetIP(net.IP)
GetName() string
SetName(string)
GetDNSLabel() string
SetDNSLabel(string)
GetUserID() string
SetUserID(string)
GetSSHKey() string
SetSSHKey(string)
GetSSHEnabled() bool
SetSSHEnabled(bool)
AddedWithSSOLogin() bool
UpdateMetaIfNew(meta PeerSystemMeta) bool
MarkLoginExpired(expired bool)
FQDN(dnsDomain string) string
EventMeta(dnsDomain string) map[string]any
LoginExpired(expiresIn time.Duration) (bool, time.Duration)
Copy() Peer
}
// Peer represents a machine connected to the network.
// The Peer is a WireGuard peer identified by a public key
type DefaultPeer struct {
// ID is an internal ID of the peer
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
// IP address of the Peer
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer
UserID string
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
// LastLogin the time when peer performed last login operation
LastLogin time.Time
// CreatedAt records the time the peer was created
CreatedAt time.Time
// Indicate ephemeral peer attribute
Ephemeral bool
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"`
}
// Location is a geo location information of a Peer based on public connection IP
type Location struct {
ConnectionIP net.IP // from grpc peer or reverse proxy headers depends on setup
CountryCode string
CityName string
GeoNameID uint // city level geoname id
}
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
type PeerLogin struct {
// WireGuardPubKey is a peers WireGuard public key
WireGuardPubKey string
// SSHKey is a peer's ssh key. Can be empty (e.g., old version do not provide it, or this feature is disabled)
SSHKey string
// Meta is the system information passed by peer, must be always present.
Meta PeerSystemMeta
// UserID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required.
UserID string
// AccountID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required.
AccountID string
// SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required.
SetupKey string
}
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
func (p *DefaultPeer) AddedWithSSOLogin() bool {
return p.UserID != ""
}
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *DefaultPeer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return false
}
p.Meta = meta
return true
}
// MarkLoginExpired marks peer's status expired or not
func (p *DefaultPeer) MarkLoginExpired(expired bool) {
newStatus := p.Status.Copy()
newStatus.LoginExpired = expired
if expired {
newStatus.Connected = false
}
p.Status = newStatus
}
// LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// Only peers added by interactive SSO login can be expired.
func (p *DefaultPeer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
return false, 0
}
expiresAt := p.LastLogin.Add(expiresIn)
now := time.Now()
timeLeft := expiresAt.Sub(now)
return timeLeft <= 0, timeLeft
}
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
func (p *DefaultPeer) FQDN(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
// EventMeta returns activity event meta related to the peer
func (p *DefaultPeer) EventMeta(dnsDomain string) map[string]any {
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt}
}
func (p *DefaultPeer) GetID() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetID(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetAccountID() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetAccountID(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetKey() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetKey(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetSetupKey() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetSetupKey(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetIP() net.IP {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetIP(ip net.IP) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetName() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetName(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetDNSLabel() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetDNSLabel(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetUserID() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetUserID(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetSSHKey() string {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetSSHKey(s string) {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) GetSSHEnabled() bool {
// TODO implement me
panic("implement me")
}
func (p *DefaultPeer) SetSSHEnabled(b bool) {
// TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,24 @@
package types
import "time"
// Copy PeerStatus
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
Connected: p.Connected,
LoginExpired: p.LoginExpired,
RequiresApproval: p.RequiresApproval,
}
}
type PeerStatus struct { //nolint:revive
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired
LoginExpired bool
// RequiresApproval indicates whether peer requires approval or not
RequiresApproval bool
}

View File

@@ -0,0 +1,69 @@
package types
import "net/netip"
// NetworkAddress is the IP address with network and MAC address of a network interface
type NetworkAddress struct {
NetIP netip.Prefix `gorm:"serializer:json"`
Mac string
}
// Environment is a system environment information
type Environment struct {
Cloud string
Platform string
}
// PeerSystemMeta is a metadata of a Peer machine system
type PeerSystemMeta struct { //nolint:revive
Hostname string
GoOS string
Kernel string
Core string
Platform string
OS string
OSVersion string
WtVersion string
UIVersion string
KernelVersion string
NetworkAddresses []NetworkAddress `gorm:"serializer:json"`
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment `gorm:"serializer:json"`
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
if len(p.NetworkAddresses) != len(other.NetworkAddresses) {
return false
}
for _, addr := range p.NetworkAddresses {
var found bool
for _, oAddr := range other.NetworkAddresses {
if addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP {
found = true
continue
}
}
if !found {
return false
}
}
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.KernelVersion == other.KernelVersion &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.OSVersion == other.OSVersion &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion &&
p.SystemSerialNumber == other.SystemSerialNumber &&
p.SystemProductName == other.SystemProductName &&
p.SystemManufacturer == other.SystemManufacturer &&
p.Environment.Cloud == other.Environment.Cloud &&
p.Environment.Platform == other.Environment.Platform
}

View File

@@ -0,0 +1 @@
package policies

View File

@@ -0,0 +1,33 @@
package policies
import (
"github.com/netbirdio/netbird/management/refactor/resources/peers"
"github.com/netbirdio/netbird/management/refactor/resources/peers/types"
)
type Manager interface {
GetAccessiblePeersAndFirewallRules(peerID string) (peers []types.Peer, firewallRules []*FirewallRule)
}
type DefaultManager struct {
repository Repository
peerManager peers.Manager
}
func NewDefaultManager(repository Repository, peerManager peers.Manager) *DefaultManager {
return &DefaultManager{
repository: repository,
peerManager: peerManager,
}
}
func (dm *DefaultManager) GetAccessiblePeersAndFirewallRules(peerID string) (peers []types.Peer, firewallRules []*FirewallRule) {
peer, err := dm.peerManager.GetPeerByID(peerID)
if err != nil {
return nil, nil
}
peers, err = dm.peerManager.GetNetworkPeersInAccount(peer.GetAccountID())
return peers, nil
}

View File

@@ -0,0 +1 @@
package posture

View File

@@ -0,0 +1 @@
package posture

View File

@@ -0,0 +1 @@
package posture

View File

@@ -0,0 +1,4 @@
package policies
type Repository interface {
}

View File

@@ -0,0 +1,19 @@
package types
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
PeerIP string
// Direction of the traffic
Direction int
// Action of the traffic
Action string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port string
}

View File

@@ -0,0 +1,13 @@
package types
type Policy interface {
GetID() string
}
type DefaultPolicy struct {
ID string
}
func (dp *DefaultPolicy) GetID() string {
return dp.ID
}

View File

@@ -0,0 +1,13 @@
package types
type PolicyRule interface {
GetID() string
}
type DefaultPolicyRule struct {
ID string
}
func (dpr *DefaultPolicyRule) GetID() string {
return dpr.ID
}

View File

@@ -0,0 +1 @@
package routes

View File

@@ -0,0 +1,100 @@
package routes
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/refactor/resources/peers"
"github.com/netbirdio/netbird/management/refactor/resources/peers/types"
routeTypes "github.com/netbirdio/netbird/management/refactor/resources/routes/types"
"github.com/netbirdio/netbird/route"
)
type Manager interface {
GetRoutesToSync(peerID string, peersToConnect []*types.Peer, accountID string) []*routeTypes.Route
}
type DefaultManager struct {
repository Repository
peersManager peers.Manager
}
func NewDefaultManager(repository Repository, peersManager peers.Manager) *DefaultManager {
return &DefaultManager{
repository: repository,
peersManager: peersManager,
}
}
func (d DefaultManager) GetRoutesToSync(peerID string, peersToConnect []*types.Peer) []*routeTypes.Route {
routes, peerDisabledRoutes := d.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{}
}
groupListMap := a.getPeerGroups(peerID)
for _, peer := range aclPeers {
activeRoutes, _ := a.getRoutingPeerRoutes(peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
func (d DefaultManager) getRoutingPeerRoutes(accountID, peerID string) (enabledRoutes []routeTypes.Route, disabledRoutes []routeTypes.Route) {
peer, err := d.peersManager.GetPeerByID(peerID)
if err != nil {
log.Errorf("peer %s that doesn't exist under account %s", peerID, accountID)
return nil, nil
}
// currently we support only linux routing peers
if peer.Meta.GoOS != "linux" {
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[string]struct{})
takeRoute := func(r routeTypes.Route, id string) {
if _, ok := seenRoute[r.GetID()]; ok {
return
}
seenRoute[r.GetID()] = struct{}{}
if r.IsEnabled() {
r.SetPeer(peer.GetKey())
enabledRoutes = append(enabledRoutes, r)
return
}
disabledRoutes = append(disabledRoutes, r)
}
for _, r := range a.Routes {
for _, groupID := range r.PeerGroups {
group := a.GetGroup(groupID)
if group == nil {
log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
continue
}
for _, id := range group.Peers {
if id != peerID {
continue
}
newPeerRoute := r.Copy()
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id)
break
}
}
if r.Peer == peerID {
takeRoute(r.Copy(), peerID)
}
}
return enabledRoutes, disabledRoutes
}

View File

@@ -0,0 +1,4 @@
package routes
type Repository interface {
}

View File

@@ -0,0 +1,54 @@
package types
import "net/netip"
const (
// InvalidNetwork invalid network type
InvalidNetwork NetworkType = iota
// IPv4Network IPv4 network type
IPv4Network
// IPv6Network IPv6 network type
IPv6Network
)
// NetworkType route network type
type NetworkType int
type Route interface {
GetID() string
IsEnabled() bool
GetPeer() string
SetPeer(string)
}
type DefaultRoute struct {
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
Network netip.Prefix `gorm:"serializer:gob"`
NetID string
Description string
Peer string
PeerGroups []string `gorm:"serializer:gob"`
NetworkType NetworkType
Masquerade bool
Metric int
Enabled bool
Groups []string `gorm:"serializer:json"`
}
func (r *DefaultRoute) GetID() string {
return r.ID
}
func (r *DefaultRoute) IsEnabled() bool {
return r.Enabled
}
func (r *DefaultRoute) GetPeer() string {
return r.Peer
}
func (r *DefaultRoute) SetPeer(peer string) {
r.Peer = peer
}

View File

@@ -0,0 +1 @@
package settings

View File

@@ -0,0 +1,21 @@
package settings
import "github.com/netbirdio/netbird/management/refactor/resources/settings/types"
type Manager interface {
GetSettings(accountID string) (types.Settings, error)
}
type DefaultManager struct {
repository Repository
}
func NewDefaultManager(repository Repository) *DefaultManager {
return &DefaultManager{
repository: repository,
}
}
func (dm *DefaultManager) GetSettings(accountID string) (types.Settings, error) {
return dm.repository.FindSettings(accountID)
}

View File

@@ -0,0 +1,7 @@
package settings
import "github.com/netbirdio/netbird/management/refactor/resources/settings/types"
type Repository interface {
FindSettings(accountID string) (types.Settings, error)
}

View File

@@ -0,0 +1,34 @@
package types
import "time"
type Settings interface {
GetLicense() string
GetPeerLoginExpiration() time.Duration
SetPeerLoginExpiration(duration time.Duration)
GetPeerLoginExpirationEnabled() bool
SetPeerLoginExpirationEnabled(bool)
}
type DefaultSettings struct {
}
func (s *DefaultSettings) GetLicense() string {
return "selfhosted"
}
func (s *DefaultSettings) GetPeerLoginExpiration() time.Duration {
return 0
}
func (s *DefaultSettings) SetPeerLoginExpiration(duration time.Duration) {
}
func (s *DefaultSettings) GetPeerLoginExpirationEnabled() bool {
return false
}
func (s *DefaultSettings) SetPeerLoginExpirationEnabled(bool) {
}

View File

@@ -0,0 +1 @@
package setup_keys

View File

@@ -0,0 +1 @@
package setup_keys

View File

@@ -0,0 +1 @@
package setup_keys

View File

@@ -0,0 +1,7 @@
package types
type SetupKey interface {
}
type DefaultSetupKey struct {
}

View File

@@ -0,0 +1 @@
package users

View File

@@ -0,0 +1,27 @@
package users
import (
"github.com/netbirdio/netbird/management/refactor/resources/peers"
"github.com/netbirdio/netbird/management/refactor/resources/users/types"
)
type Manager interface {
GetUser(id string) (types.User, error)
}
type DefaultManager struct {
repository Repository
peerManager peers.Manager
}
func NewDefaultManager(repository Repository, peerManager peers.Manager) *DefaultManager {
return &DefaultManager{
repository: repository,
peerManager: peerManager,
}
}
func (d DefaultManager) GetUser(id string) (types.User, error) {
// TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1 @@
package personal_access_tokens

View File

@@ -0,0 +1 @@
package personal_access_tokens

View File

@@ -0,0 +1 @@
package personal_access_tokens

View File

@@ -0,0 +1,4 @@
package users
type Repository interface {
}

View File

@@ -0,0 +1,35 @@
package types
import "time"
// UserRole is the role of a User
type UserRole string
type User interface {
IsBlocked() bool
}
// User represents a user of the system
type DefaultUser struct {
Id string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Role UserRole
IsServiceUser bool
// NonDeletable indicates whether the service user can be deleted
NonDeletable bool
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP
LastLogin time.Time
// Issued of the user
Issued string `gorm:"default:api"`
}
func (u *DefaultUser) IsBlocked() bool {
return u.Blocked
}

View File

@@ -0,0 +1,50 @@
package store
import (
"time"
dnsTypes "github.com/netbirdio/netbird/management/refactor/resources/dns/types"
groupTypes "github.com/netbirdio/netbird/management/refactor/resources/groups/types"
networkTypes "github.com/netbirdio/netbird/management/refactor/resources/network/types"
peerTypes "github.com/netbirdio/netbird/management/refactor/resources/peers/types"
policyTypes "github.com/netbirdio/netbird/management/refactor/resources/policies/types"
routeTypes "github.com/netbirdio/netbird/management/refactor/resources/routes/types"
settingsTypes "github.com/netbirdio/netbird/management/refactor/resources/settings/types"
setupKeyTypes "github.com/netbirdio/netbird/management/refactor/resources/setup_keys/types"
userTypes "github.com/netbirdio/netbird/management/refactor/resources/users/types"
"github.com/netbirdio/netbird/management/server/posture"
)
// Account represents a unique account of the system
type DefaultAccount struct {
// we have to name column to aid as it collides with Network.Id when work with associations
Id string `gorm:"primaryKey"`
// User.Id it was created by
CreatedBy string
CreatedAt time.Time
Domain string `gorm:"index"`
DomainCategory string
IsDomainPrimaryAccount bool
SetupKeys map[string]*setupKeyTypes.DefaultSetupKey `gorm:"-"`
SetupKeysG []setupKeyTypes.DefaultSetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
Network *networkTypes.Network `gorm:"embedded;embeddedPrefix:network_"`
Peers map[string]*peerTypes.DefaultPeer `gorm:"-"`
PeersG []peerTypes.DefaultPeer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*userTypes.DefaultUser `gorm:"-"`
UsersG []userTypes.DefaultUser `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*groupTypes.DefaultGroup `gorm:"-"`
GroupsG []groupTypes.DefaultGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*policyTypes.DefaultPolicy `gorm:"foreignKey:AccountID;references:id"`
Routes map[string]*routeTypes.DefaultRoute `gorm:"-"`
RoutesG []routeTypes.DefaultRoute `json:"-" gorm:"foreignKey:AccountID;references:id"`
NameServerGroups map[string]*dnsTypes.DefaultNameServerGroup `gorm:"-"`
NameServerGroupsG []dnsTypes.DefaultNameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings dnsTypes.DefaultSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *settingsTypes.DefaultSettings `gorm:"embedded;embeddedPrefix:settings_"`
// deprecated on store and api level
Rules map[string]*Rule `json:"-" gorm:"-"`
RulesG []Rule `json:"-" gorm:"-"`
}

View File

@@ -0,0 +1,51 @@
package store
import (
peerTypes "github.com/netbirdio/netbird/management/refactor/resources/peers/types"
settingsTypes "github.com/netbirdio/netbird/management/refactor/resources/settings/types"
)
const (
PostgresStoreEngine StoreEngine = "postgres"
)
type DefaultPostgresStore struct {
}
func (s *DefaultPostgresStore) FindSettings(accountID string) (settingsTypes.Settings, error) {
// TODO implement me
panic("implement me")
}
func (s *DefaultPostgresStore) FindPeerByPubKey(pubKey string) (peerTypes.Peer, error) {
// TODO implement me
panic("implement me")
}
func (s *DefaultPostgresStore) FindPeerByID(id string) (peerTypes.Peer, error) {
// TODO implement me
panic("implement me")
}
func (s *DefaultPostgresStore) FindAllPeersInAccount(id string) ([]peerTypes.Peer, error) {
// TODO implement me
panic("implement me")
}
func (s *DefaultPostgresStore) UpdatePeer(peer peerTypes.Peer) error {
// TODO implement me
panic("implement me")
}
func (s *DefaultPostgresStore) GetLicense() string {
// TODO implement me
panic("implement me")
}
func NewDefaultPostgresStore() *DefaultPostgresStore {
return &DefaultPostgresStore{}
}
func (s *DefaultPostgresStore) GetEngine() StoreEngine {
return PostgresStoreEngine
}

View File

@@ -0,0 +1,345 @@
package store
import (
"errors"
"path/filepath"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
dnsTypes "github.com/netbirdio/netbird/management/refactor/resources/dns/types"
groupTypes "github.com/netbirdio/netbird/management/refactor/resources/groups/types"
"github.com/netbirdio/netbird/management/refactor/resources/peers"
policyTypes "github.com/netbirdio/netbird/management/refactor/resources/policies/types"
routeTypes "github.com/netbirdio/netbird/management/refactor/resources/routes/types"
"github.com/netbirdio/netbird/management/refactor/resources/settings"
setupKeyTypes "github.com/netbirdio/netbird/management/refactor/resources/setup_keys/types"
userTypes "github.com/netbirdio/netbird/management/refactor/resources/users/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
SqliteStoreEngine StoreEngine = "sqlite"
)
// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk
type DefaultSqliteStore struct {
DB *gorm.DB
storeFile string
accountLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
accounts map[string]*DefaultAccount
}
type installation struct {
ID uint `gorm:"primaryKey"`
InstallationIDValue string
}
// NewSqliteStore restores a store from the file located in the datadir
func NewDefaultSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*DefaultSqliteStore, error) {
storeStr := "store.DB?cache=shared"
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = "store.DB"
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
})
if err != nil {
return nil, err
}
sql, err := db.DB()
if err != nil {
return nil, err
}
conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable
// err = DB.AutoMigrate(
// &SetupKey{}, &Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{},
// &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
// &installation{},
// )
// if err != nil {
// return nil, err
// }
return &DefaultSqliteStore{DB: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *DefaultSqliteStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
}
took := time.Since(start)
log.Debugf("took %v to acquire global lock", took)
if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
}
return unlock
}
func (s *DefaultSqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *DefaultSqliteStore) LoadAccount(accountID string) error {
var account DefaultAccount
result := s.DB.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, "id = ?", accountID)
if result.Error != nil {
log.Errorf("error when getting account from the store: %s", result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "account not found")
}
return status.Errorf(status.Internal, "issue getting account from store")
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*policyTypes.DefaultPolicyRule
err := s.DB.Model(&policyTypes.DefaultPolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return status.Errorf(status.NotFound, "rule not found")
}
account.Policies[i].Rules = rules
}
account.SetupKeys = make(map[string]*setupKeyTypes.DefaultSetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy()
}
account.PeersG = nil
account.Users = make(map[string]*userTypes.DefaultUser, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
account.Users[user.Id] = user.Copy()
}
account.UsersG = nil
account.Groups = make(map[string]*groupTypes.DefaultGroup, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
account.Routes = make(map[string]*routeTypes.DefaultRoute, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*dnsTypes.DefaultNameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
account.NameServerGroups[ns.ID] = ns.Copy()
}
account.NameServerGroupsG = nil
s.accounts[account.Id] = &account
return nil
}
func (s *DefaultSqliteStore) WriteAccount(accountID string) error {
start := time.Now()
account, ok := s.accounts[accountID]
if !ok {
return status.Errorf(status.NotFound, "account not found")
}
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
}
for id, peer := range account.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
}
for id, user := range account.Users {
user.Id = id
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
}
for id, route := range account.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
}
err := s.DB.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
if result.Error != nil {
return result.Error
}
return nil
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
return err
}
func (s *DefaultSqliteStore) SaveInstallationID(ID string) error {
installation := installation{InstallationIDValue: ID}
installation.ID = uint(s.installationPK)
return s.DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
}
func (s *DefaultSqliteStore) GetInstallationID() string {
var installation installation
if result := s.DB.First(&installation, "id = ?", s.installationPK); result.Error != nil {
return ""
}
return installation.InstallationIDValue
}
// Close is noop in Sqlite
func (s *DefaultSqliteStore) Close() error {
return nil
}
// GetStoreEngine returns SqliteStoreEngine
func (s *DefaultSqliteStore) GetStoreEngine() StoreEngine {
return SqliteStoreEngine
}
func (s *DefaultSqliteStore) GetLicense() string {
// TODO implement me
panic("implement me")
}
func (s *DefaultSqliteStore) FindSettings(accountID string) (settings.Settings, error) {
account, ok := s.accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
}
return account.Settings, nil
}
func (s *DefaultSqliteStore) FindPeerByPubKey(accountID string, pubKey string) (peers.Peer, error) {
a, ok := s.accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
}
for _, peer := range a.Peers {
if peer.Key == pubKey {
return peer.Copy(), nil
}
}
return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", pubKey)
}
func (s *DefaultSqliteStore) FindPeerByID(accountID string, id string) (peers.Peer, error) {
a, ok := s.accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
}
for _, peer := range a.Peers {
if peer.ID == id {
return peer.Copy(), nil
}
}
return nil, status.Errorf(status.NotFound, "peer with the ID %s not found", id)
}
func (s *DefaultSqliteStore) FindAllPeersInAccount(accountId string) ([]peers.Peer, error) {
a, ok := s.accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
}
return a.Peers, nil
}
func (s *DefaultSqliteStore) UpdatePeer(peer peers.Peer) error {
// TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,65 @@
package store
import (
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
peerTypes "github.com/netbirdio/netbird/management/refactor/resources/peers/types"
settingsTypes "github.com/netbirdio/netbird/management/refactor/resources/settings/types"
"github.com/netbirdio/netbird/management/server/telemetry"
)
type Store interface {
AcquireAccountLock(id string) func()
AcquireGlobalLock() func()
LoadAccount(id string) error
WriteAccount(id string) error
GetLicense() string
FindPeerByPubKey(pubKey string) (peerTypes.Peer, error)
FindPeerByID(id string) (peerTypes.Peer, error)
FindAllPeersInAccount(id string) ([]peerTypes.Peer, error)
UpdatePeer(peer peerTypes.Peer) error
FindSettings(accountID string) (settingsTypes.Settings, error)
}
type DefaultStore interface {
Store
}
type StoreEngine string
func getStoreEngineFromEnv() StoreEngine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
if !ok {
return SqliteStoreEngine
}
value := StoreEngine(strings.ToLower(kind))
if value == PostgresStoreEngine || value == SqliteStoreEngine {
return value
}
return SqliteStoreEngine
}
func NewDefaultStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (DefaultStore, error) {
if kind == "" {
// fallback to env. Normally this only should be used from tests
kind = getStoreEngineFromEnv()
}
switch kind {
case PostgresStoreEngine:
log.Info("using JSON file store engine")
return NewDefaultPostgresStore(), nil
case SqliteStoreEngine:
log.Info("using SQLite store engine")
return NewDefaultSqliteStore(dataDir, metrics)
default:
return nil, fmt.Errorf("unsupported kind of store %s", kind)
}
}