mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 18:12:30 -04:00
Compare commits
11 Commits
prototype/
...
v0.66.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05b66e73bc | ||
|
|
01ceedac89 | ||
|
|
403babd433 | ||
|
|
47133031e5 | ||
|
|
82da606886 | ||
|
|
bbe5ae2145 | ||
|
|
0b21498b39 | ||
|
|
0ca59535f1 | ||
|
|
59c77d0658 | ||
|
|
333e045099 | ||
|
|
c2c4d9d336 |
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.23.2
|
||||
FROM alpine:3.23.3
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -331,8 +331,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
@@ -49,7 +49,7 @@ func ResolveUnixDaemonAddr(addr string) string {
|
||||
switch len(found) {
|
||||
case 1:
|
||||
resolved := "unix://" + found[0]
|
||||
log.Infof("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||
return resolved
|
||||
case 0:
|
||||
return addr
|
||||
|
||||
@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
|
||||
configDir := filepath.Join(DefaultConfigPathDir, username)
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(configDir, 0600); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -206,9 +206,15 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
func fileExists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
return !os.IsNotExist(err)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@@ -635,7 +641,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||
}
|
||||
|
||||
@@ -644,7 +654,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
@@ -657,7 +671,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err := util.EnforcePermission(input.ConfigPath)
|
||||
err = util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
@@ -784,7 +798,12 @@ func ReadConfig(configPath string) (*Config, error) {
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
if fileExists(configPath) {
|
||||
configExists, err := fileExists(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
|
||||
if configExists {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
@@ -831,7 +850,11 @@ func DirectWriteOutConfig(path string, config *Config) error {
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
|
||||
@@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
if fileExists(profPath) {
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if profileExists {
|
||||
return ErrProfileAlreadyExists
|
||||
}
|
||||
|
||||
@@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
if !fileExists(profPath) {
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if !profileExists {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
if !fileExists(stateFile) {
|
||||
stateFileExists, err := fileExists(stateFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||
}
|
||||
if !stateFileExists {
|
||||
return nil, errors.New("profile state file does not exist")
|
||||
}
|
||||
|
||||
|
||||
@@ -263,8 +263,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
|
||||
case <-closer:
|
||||
return
|
||||
case routerStates := <-subscription.Events():
|
||||
peerStateUpdate <- routerStates
|
||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||
select {
|
||||
case peerStateUpdate <- routerStates:
|
||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-closer:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,8 +641,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
|
||||
log.Warnf(errRestoreResidualState, err)
|
||||
}
|
||||
@@ -654,10 +652,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
// not in the progress or already successfully established connection.
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if status != internal.StatusIdle {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("up already in progress: current status %s", status)
|
||||
}
|
||||
|
||||
@@ -674,17 +674,20 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.actCancel = cancel
|
||||
|
||||
if s.config == nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("config is not defined, please call login command first")
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to get active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
@@ -692,6 +695,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
|
||||
activeProf, err = s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to get active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
@@ -700,6 +704,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
||||
}
|
||||
@@ -718,6 +723,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
}
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -71,6 +72,7 @@ type ServerConfig struct {
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Store StoreConfig `yaml:"store"`
|
||||
ActivityStore StoreConfig `yaml:"activityStore"`
|
||||
AuthStore StoreConfig `yaml:"authStore"`
|
||||
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||
}
|
||||
|
||||
@@ -171,7 +173,8 @@ type RelaysConfig struct {
|
||||
type StoreConfig struct {
|
||||
Engine string `yaml:"engine"`
|
||||
EncryptionKey string `yaml:"encryptionKey"`
|
||||
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
||||
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
|
||||
File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir)
|
||||
}
|
||||
|
||||
// ReverseProxyConfig contains reverse proxy settings
|
||||
@@ -533,6 +536,74 @@ func stripSignalProtocol(uri string) string {
|
||||
return uri
|
||||
}
|
||||
|
||||
func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) {
|
||||
var ttl time.Duration
|
||||
if relays.CredentialsTTL != "" {
|
||||
var err error
|
||||
ttl, err = time.ParseDuration(relays.CredentialsTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err)
|
||||
}
|
||||
}
|
||||
return &nbconfig.Relay{
|
||||
Addresses: relays.Addresses,
|
||||
CredentialsTTL: util.Duration{Duration: ttl},
|
||||
Secret: relays.Secret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildEmbeddedIdPConfig builds the embedded IdP configuration.
|
||||
// authStore overrides auth.storage when set.
|
||||
func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) {
|
||||
authStorageType := mgmt.Auth.Storage.Type
|
||||
authStorageDSN := c.Server.AuthStore.DSN
|
||||
if c.Server.AuthStore.Engine != "" {
|
||||
authStorageType = c.Server.AuthStore.Engine
|
||||
}
|
||||
if authStorageType == "" {
|
||||
authStorageType = "sqlite3"
|
||||
}
|
||||
authStorageFile := ""
|
||||
if authStorageType == "postgres" {
|
||||
if authStorageDSN == "" {
|
||||
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
|
||||
}
|
||||
} else {
|
||||
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||
if c.Server.AuthStore.File != "" {
|
||||
authStorageFile = c.Server.AuthStore.File
|
||||
if !filepath.IsAbs(authStorageFile) {
|
||||
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: mgmt.Auth.Issuer,
|
||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: authStorageType,
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: authStorageFile,
|
||||
DSN: authStorageDSN,
|
||||
},
|
||||
},
|
||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||
}
|
||||
|
||||
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||
cfg.Owner = &idp.OwnerConfig{
|
||||
Email: mgmt.Auth.Owner.Email,
|
||||
Hash: mgmt.Auth.Owner.Password,
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ToManagementConfig converts CombinedConfig to management server config
|
||||
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
mgmt := c.Management
|
||||
@@ -551,19 +622,11 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
// Build relay config
|
||||
var relayConfig *nbconfig.Relay
|
||||
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
|
||||
var ttl time.Duration
|
||||
if mgmt.Relays.CredentialsTTL != "" {
|
||||
var err error
|
||||
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
|
||||
}
|
||||
}
|
||||
relayConfig = &nbconfig.Relay{
|
||||
Addresses: mgmt.Relays.Addresses,
|
||||
CredentialsTTL: util.Duration{Duration: ttl},
|
||||
Secret: mgmt.Relays.Secret,
|
||||
relay, err := buildRelayConfig(mgmt.Relays)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relayConfig = relay
|
||||
}
|
||||
|
||||
// Build signal config
|
||||
@@ -599,31 +662,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
httpConfig := &nbconfig.HttpServerConfig{}
|
||||
|
||||
// Build embedded IDP config (always enabled in combined server)
|
||||
storageFile := mgmt.Auth.Storage.File
|
||||
if storageFile == "" {
|
||||
storageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||
}
|
||||
|
||||
embeddedIdP := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: mgmt.Auth.Issuer,
|
||||
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
|
||||
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: mgmt.Auth.Storage.Type,
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: storageFile,
|
||||
},
|
||||
},
|
||||
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
|
||||
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
|
||||
}
|
||||
|
||||
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
|
||||
embeddedIdP.Owner = &idp.OwnerConfig{
|
||||
Email: mgmt.Auth.Owner.Email,
|
||||
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
|
||||
}
|
||||
embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set HTTP config fields for embedded IDP
|
||||
|
||||
@@ -140,6 +140,9 @@ func initializeConfig() error {
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := config.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
if engine := config.Server.ActivityStore.Engine; engine != "" {
|
||||
engineLower := strings.ToLower(engine)
|
||||
@@ -151,6 +154,9 @@ func initializeConfig() error {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := config.Server.ActivityStore.File; file != "" {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
log.Infof("Starting combined NetBird server")
|
||||
logConfig(config)
|
||||
|
||||
@@ -42,6 +42,9 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
|
||||
@@ -103,11 +103,19 @@ server:
|
||||
engine: "sqlite" # sqlite, postgres, or mysql
|
||||
dsn: "" # Connection string for postgres or mysql
|
||||
encryptionKey: ""
|
||||
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db)
|
||||
|
||||
# Activity events store configuration (optional, defaults to sqlite in dataDir)
|
||||
# activityStore:
|
||||
# engine: "sqlite" # sqlite or postgres
|
||||
# dsn: "" # Connection string for postgres
|
||||
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db)
|
||||
|
||||
# Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db)
|
||||
# authStore:
|
||||
# engine: "sqlite3" # sqlite3 or postgres
|
||||
# dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable")
|
||||
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db)
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
|
||||
@@ -5,7 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -195,11 +198,175 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||
}
|
||||
return (&sql.SQLite3{File: file}).Open(logger)
|
||||
case "postgres":
|
||||
dsn, _ := s.Config["dsn"].(string)
|
||||
if dsn == "" {
|
||||
return nil, fmt.Errorf("postgres storage requires 'dsn' config")
|
||||
}
|
||||
pg, err := parsePostgresDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid postgres DSN: %w", err)
|
||||
}
|
||||
return pg.Open(logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// parsePostgresDSN parses a DSN into a sql.Postgres config.
|
||||
// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable)
|
||||
// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values.
|
||||
func parsePostgresDSN(dsn string) (*sql.Postgres, error) {
|
||||
var params map[string]string
|
||||
var err error
|
||||
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
params, err = parsePostgresURI(dsn)
|
||||
} else {
|
||||
params, err = parsePostgresKeyValue(dsn)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host := params["host"]
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
var port uint16 = 5432
|
||||
if p, ok := params["port"]; ok && p != "" {
|
||||
v, err := strconv.ParseUint(p, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port %q: %w", p, err)
|
||||
}
|
||||
if v == 0 {
|
||||
return nil, fmt.Errorf("invalid port %q: must be non-zero", p)
|
||||
}
|
||||
port = uint16(v)
|
||||
}
|
||||
|
||||
dbname := params["dbname"]
|
||||
if dbname == "" {
|
||||
return nil, fmt.Errorf("dbname is required in DSN")
|
||||
}
|
||||
|
||||
pg := &sql.Postgres{
|
||||
NetworkDB: sql.NetworkDB{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Database: dbname,
|
||||
User: params["user"],
|
||||
Password: params["password"],
|
||||
},
|
||||
}
|
||||
|
||||
if sslMode := params["sslmode"]; sslMode != "" {
|
||||
switch sslMode {
|
||||
case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
|
||||
pg.SSL.Mode = sslMode
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode)
|
||||
}
|
||||
}
|
||||
|
||||
return pg, nil
|
||||
}
|
||||
|
||||
// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs.
|
||||
func parsePostgresURI(dsn string) (map[string]string, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid postgres URI: %w", err)
|
||||
}
|
||||
|
||||
params := make(map[string]string)
|
||||
|
||||
if u.User != nil {
|
||||
params["user"] = u.User.Username()
|
||||
if p, ok := u.User.Password(); ok {
|
||||
params["password"] = p
|
||||
}
|
||||
}
|
||||
if u.Hostname() != "" {
|
||||
params["host"] = u.Hostname()
|
||||
}
|
||||
if u.Port() != "" {
|
||||
params["port"] = u.Port()
|
||||
}
|
||||
|
||||
dbname := strings.TrimPrefix(u.Path, "/")
|
||||
if dbname != "" {
|
||||
params["dbname"] = dbname
|
||||
}
|
||||
|
||||
for k, v := range u.Query() {
|
||||
if len(v) > 0 {
|
||||
params[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values
|
||||
// (e.g., password='my pass' host=localhost).
|
||||
func parsePostgresKeyValue(dsn string) (map[string]string, error) {
|
||||
params := make(map[string]string)
|
||||
s := strings.TrimSpace(dsn)
|
||||
|
||||
for s != "" {
|
||||
eqIdx := strings.IndexByte(s, '=')
|
||||
if eqIdx < 0 {
|
||||
break
|
||||
}
|
||||
key := strings.TrimSpace(s[:eqIdx])
|
||||
|
||||
value, rest, err := parseDSNValue(s[eqIdx+1:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w for key %q", err, key)
|
||||
}
|
||||
|
||||
params[key] = value
|
||||
s = strings.TrimSpace(rest)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// parseDSNValue parses the next value from a libpq key=value string positioned after the '='.
|
||||
// It returns the parsed value and the remaining unparsed string.
|
||||
func parseDSNValue(s string) (value, rest string, err error) {
|
||||
if len(s) > 0 && s[0] == '\'' {
|
||||
return parseQuotedDSNValue(s[1:])
|
||||
}
|
||||
// Unquoted value: read until whitespace.
|
||||
idx := strings.IndexAny(s, " \t\n")
|
||||
if idx < 0 {
|
||||
return s, "", nil
|
||||
}
|
||||
return s[:idx], s[idx:], nil
|
||||
}
|
||||
|
||||
// parseQuotedDSNValue parses a single-quoted value starting after the opening quote.
|
||||
// Libpq uses ” to represent a literal single quote inside quoted values.
|
||||
func parseQuotedDSNValue(s string) (value, rest string, err error) {
|
||||
var buf strings.Builder
|
||||
for len(s) > 0 {
|
||||
if s[0] == '\'' {
|
||||
if len(s) > 1 && s[1] == '\'' {
|
||||
buf.WriteByte('\'')
|
||||
s = s[2:]
|
||||
continue
|
||||
}
|
||||
return buf.String(), s[1:], nil
|
||||
}
|
||||
buf.WriteByte(s[0])
|
||||
s = s[1:]
|
||||
}
|
||||
return "", "", fmt.Errorf("unterminated quoted value")
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *YAMLConfig) Validate() error {
|
||||
if c.Issuer == "" {
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"slices"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
@@ -410,12 +410,15 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
|
||||
var service *reverseproxy.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,11 +13,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -1112,3 +1115,67 @@ func TestGetGroupIDsFromNames(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "no group names provided")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
userID := "test-user"
|
||||
|
||||
sqlStore, err := store.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPerms := permissions.NewMockManager(ctrl)
|
||||
mockAcct := account.NewMockManager(ctrl)
|
||||
mockGRPC := &nbgrpc.ProxyServiceServer{}
|
||||
|
||||
mgr := &managerImpl{
|
||||
store: sqlStore,
|
||||
permissionsManager: mockPerms,
|
||||
accountManager: mockAcct,
|
||||
proxyGRPCServer: mockGRPC,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-1",
|
||||
AccountID: accountID,
|
||||
Domain: "test.example.com",
|
||||
ProxyCluster: "cluster1",
|
||||
Enabled: true,
|
||||
Targets: []*reverseproxy.Target{
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-1"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-2"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-3"},
|
||||
},
|
||||
}
|
||||
|
||||
err = sqlStore.CreateService(ctx, service)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrievedService, err := sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
|
||||
|
||||
mockPerms.EXPECT().
|
||||
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
|
||||
Return(true, nil)
|
||||
mockAcct.EXPECT().
|
||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||
mockAcct.EXPECT().
|
||||
UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.NotFound, s.Type())
|
||||
|
||||
targets, err := sqlStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, accountID, service.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package account
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
@@ -61,11 +63,11 @@ type Manager interface {
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
|
||||
1738
management/server/account/manager_mock.go
Normal file
1738
management/server/account/manager_mock.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
|
||||
|
||||
switch storeEngine {
|
||||
case types.SqliteStoreEngine:
|
||||
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
|
||||
dbFile := eventSinkDB
|
||||
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
|
||||
dbFile = envFile
|
||||
}
|
||||
connStr := dbFile
|
||||
if !filepath.IsAbs(dbFile) {
|
||||
connStr = filepath.Join(dataDir, dbFile)
|
||||
}
|
||||
dialector = sqlite.Open(connStr)
|
||||
case types.PostgresStoreEngine:
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
|
||||
@@ -52,7 +52,7 @@ type EmbeddedIdPConfig struct {
|
||||
|
||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||
type EmbeddedStorageConfig struct {
|
||||
// Type is the storage type (currently only "sqlite3" is supported)
|
||||
// Type is the storage type: "sqlite3" (default) or "postgres"
|
||||
Type string
|
||||
// Config contains type-specific configuration
|
||||
Config EmbeddedStorageTypeConfig
|
||||
@@ -62,6 +62,8 @@ type EmbeddedStorageConfig struct {
|
||||
type EmbeddedStorageTypeConfig struct {
|
||||
// File is the path to the SQLite database file (for sqlite3 type)
|
||||
File string
|
||||
// DSN is the connection string for postgres
|
||||
DSN string
|
||||
}
|
||||
|
||||
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
|
||||
@@ -74,6 +76,22 @@ type OwnerConfig struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
// buildIdpStorageConfig builds the Dex storage config map based on the storage type.
|
||||
func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) {
|
||||
switch storageType {
|
||||
case "sqlite3":
|
||||
return map[string]interface{}{
|
||||
"file": cfg.File,
|
||||
}, nil
|
||||
case "postgres":
|
||||
return map[string]interface{}{
|
||||
"dsn": cfg.DSN,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType)
|
||||
}
|
||||
}
|
||||
|
||||
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
|
||||
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
if c.Issuer == "" {
|
||||
@@ -85,6 +103,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
|
||||
return nil, fmt.Errorf("storage file is required for sqlite3")
|
||||
}
|
||||
if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" {
|
||||
return nil, fmt.Errorf("storage DSN is required for postgres")
|
||||
}
|
||||
|
||||
storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
|
||||
}
|
||||
|
||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||
cliRedirectURIs := c.CLIRedirectURIs
|
||||
@@ -100,10 +126,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
Type: c.Storage.Type,
|
||||
Config: map[string]interface{}{
|
||||
"file": c.Storage.Config.File,
|
||||
},
|
||||
Type: c.Storage.Type,
|
||||
Config: storageConfig,
|
||||
},
|
||||
Web: dex.Web{
|
||||
AllowedOrigins: []string{"*"},
|
||||
|
||||
@@ -2728,14 +2728,28 @@ func (s *SqlStore) GetStoreEngine() types.Engine {
|
||||
|
||||
// NewSqliteStore creates a new SQLite store.
|
||||
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
||||
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
|
||||
if runtime.GOOS == "windows" {
|
||||
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
storeStr = storeSqliteFileName
|
||||
storeFile := storeSqliteFileName
|
||||
if envFile, ok := os.LookupEnv("NB_STORE_ENGINE_SQLITE_FILE"); ok && envFile != "" {
|
||||
storeFile = envFile
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
||||
// Separate file path from any SQLite URI query parameters (e.g., "store.db?mode=rwc")
|
||||
filePath, query, hasQuery := strings.Cut(storeFile, "?")
|
||||
|
||||
connStr := filePath
|
||||
if !filepath.IsAbs(filePath) {
|
||||
connStr = filepath.Join(dataDir, filePath)
|
||||
}
|
||||
|
||||
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
|
||||
if hasQuery {
|
||||
connStr += "?" + query
|
||||
} else if runtime.GOOS != "windows" {
|
||||
// To avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
connStr += "?cache=shared"
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4895,6 +4909,46 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete target from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete targets from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID retrieves all targets for a given service
|
||||
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error) {
|
||||
var targets []*reverseproxy.Target
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get targets from store")
|
||||
}
|
||||
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
|
||||
@@ -272,6 +272,9 @@ type Store interface {
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error)
|
||||
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
|
||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||
|
||||
// GetCustomDomainsCounts returns the total and validated custom domain counts.
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
|
||||
@@ -559,6 +559,20 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// DeleteServiceTargets mocks base method.
|
||||
func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteServiceTargets indicates an expected call of DeleteServiceTargets.
|
||||
func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// DeleteSetupKey mocks base method.
|
||||
func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -573,6 +587,20 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID)
|
||||
}
|
||||
|
||||
// DeleteTarget mocks base method.
|
||||
func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteTarget indicates an expected call of DeleteTarget.
|
||||
func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID)
|
||||
}
|
||||
|
||||
// DeleteTokenID2UserIDIndex mocks base method.
|
||||
func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1109,21 +1137,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetServicesByAccountID mocks base method.
|
||||
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountSettings mocks base method.
|
||||
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1288,6 +1301,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts mocks base method.
|
||||
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
|
||||
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
|
||||
}
|
||||
|
||||
// GetDNSRecordByID mocks base method.
|
||||
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1872,22 +1901,6 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTargetByTargetID", reflect.TypeOf((*MockStore)(nil).GetServiceTargetByTargetID), ctx, lockStrength, accountID, targetID)
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts mocks base method.
|
||||
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
|
||||
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
|
||||
}
|
||||
|
||||
// GetServices mocks base method.
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1903,6 +1916,21 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength)
|
||||
}
|
||||
|
||||
// GetServicesByAccountID mocks base method.
|
||||
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID mocks base method.
|
||||
func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1962,6 +1990,21 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId)
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID mocks base method.
|
||||
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*reverseproxy.Target, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Target)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID.
|
||||
func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID)
|
||||
}
|
||||
|
||||
// GetTokenIDByHashedToken mocks base method.
|
||||
func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -81,9 +81,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
|
||||
Transport: p.transport,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
|
||||
Transport: p.transport,
|
||||
FlushInterval: -1,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
}
|
||||
if result.rewriteRedirects {
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
|
||||
|
||||
@@ -11,6 +11,26 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// APIError represents an error response from the management API.
|
||||
type APIError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *APIError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// IsNotFound returns true if the error represents a 404 Not Found response.
|
||||
func IsNotFound(err error) bool {
|
||||
var apiErr *APIError
|
||||
if ok := errors.As(err, &apiErr); ok {
|
||||
return apiErr.StatusCode == http.StatusNotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Client Management service HTTP REST API Client
|
||||
type Client struct {
|
||||
managementURL string
|
||||
@@ -105,6 +125,15 @@ type Client struct {
|
||||
// Instance NetBird Instance API
|
||||
// see more: https://docs.netbird.io/api/resources/instance
|
||||
Instance *InstanceAPI
|
||||
|
||||
// ReverseProxyServices NetBird reverse proxy services APIs
|
||||
ReverseProxyServices *ReverseProxyServicesAPI
|
||||
|
||||
// ReverseProxyClusters NetBird reverse proxy clusters APIs
|
||||
ReverseProxyClusters *ReverseProxyClustersAPI
|
||||
|
||||
// ReverseProxyDomains NetBird reverse proxy domains APIs
|
||||
ReverseProxyDomains *ReverseProxyDomainsAPI
|
||||
}
|
||||
|
||||
// New initialize new Client instance using PAT token
|
||||
@@ -160,6 +189,9 @@ func (c *Client) initialize() {
|
||||
c.IdentityProviders = &IdentityProvidersAPI{c}
|
||||
c.Ingress = &IngressAPI{c}
|
||||
c.Instance = &InstanceAPI{c}
|
||||
c.ReverseProxyServices = &ReverseProxyServicesAPI{c}
|
||||
c.ReverseProxyClusters = &ReverseProxyClustersAPI{c}
|
||||
c.ReverseProxyDomains = &ReverseProxyDomainsAPI{c}
|
||||
}
|
||||
|
||||
// NewRequest creates and executes new management API request
|
||||
@@ -194,10 +226,12 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re
|
||||
if resp.StatusCode > 299 {
|
||||
parsedErr, pErr := parseResponse[util.ErrorResponse](resp)
|
||||
if pErr != nil {
|
||||
|
||||
return nil, pErr
|
||||
}
|
||||
return nil, errors.New(parsedErr.Message)
|
||||
return nil, &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Message: parsedErr.Message,
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
||||
25
shared/management/client/rest/reverse_proxy_clusters.go
Normal file
25
shared/management/client/rest/reverse_proxy_clusters.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// ReverseProxyClustersAPI APIs for Reverse Proxy Clusters, do not use directly
|
||||
type ReverseProxyClustersAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List lists all available proxy clusters
|
||||
func (a *ReverseProxyClustersAPI) List(ctx context.Context) ([]api.ProxyCluster, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/clusters", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.ProxyCluster](resp)
|
||||
return ret, err
|
||||
}
|
||||
72
shared/management/client/rest/reverse_proxy_domains.go
Normal file
72
shared/management/client/rest/reverse_proxy_domains.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// ReverseProxyDomainsAPI APIs for Reverse Proxy Domains, do not use directly
|
||||
type ReverseProxyDomainsAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List lists all reverse proxy domains
|
||||
func (a *ReverseProxyDomainsAPI) List(ctx context.Context) ([]api.ReverseProxyDomain, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/domains", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.ReverseProxyDomain](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Create creates a new custom domain
|
||||
func (a *ReverseProxyDomainsAPI) Create(ctx context.Context, request api.PostApiReverseProxiesDomainsJSONRequestBody) (*api.ReverseProxyDomain, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/reverse-proxies/domains", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.ReverseProxyDomain](resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
// Delete deletes a custom domain
|
||||
func (a *ReverseProxyDomainsAPI) Delete(ctx context.Context, domainID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/domains/"+url.PathEscape(domainID), nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate triggers domain ownership validation for a custom domain
|
||||
func (a *ReverseProxyDomainsAPI) Validate(ctx context.Context, domainID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/domains/"+url.PathEscape(domainID)+"/validate", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
97
shared/management/client/rest/reverse_proxy_services.go
Normal file
97
shared/management/client/rest/reverse_proxy_services.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// ReverseProxyServicesAPI APIs for Reverse Proxy Services, do not use directly
|
||||
type ReverseProxyServicesAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List lists all reverse proxy services
|
||||
func (a *ReverseProxyServicesAPI) List(ctx context.Context) ([]api.Service, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/services", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Service](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get retrieves a reverse proxy service by ID
|
||||
func (a *ReverseProxyServicesAPI) Get(ctx context.Context, serviceID string) (*api.Service, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Service](resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
// Create creates a new reverse proxy service
|
||||
func (a *ReverseProxyServicesAPI) Create(ctx context.Context, request api.PostApiReverseProxiesServicesJSONRequestBody) (*api.Service, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/reverse-proxies/services", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Service](resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
// Update updates a reverse proxy service
|
||||
func (a *ReverseProxyServicesAPI) Update(ctx context.Context, serviceID string, request api.PutApiReverseProxiesServicesServiceIdJSONRequestBody) (*api.Service, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Service](resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
// Delete deletes a reverse proxy service
|
||||
func (a *ReverseProxyServicesAPI) Delete(ctx context.Context, serviceID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user