diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 8af4046a7..8e672043d 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -63,10 +63,15 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV + - name: Generate test script + run: | + $packages = go list ./... 
| Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } + $goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe" + $cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1" + Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1a4676625..83444b541 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -170,6 +170,7 @@ jobs: run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu - name: Decode GPG signing key + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository env: GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} run: | @@ -309,6 +310,7 @@ jobs: run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 - name: Decode GPG signing key + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository env: GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }} run: | diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml index 47e45165b..81ae36e78 100644 --- a/.github/workflows/wasm-build-validation.yml +++ b/.github/workflows/wasm-build-validation.yml @@ -61,8 +61,8 
@@ jobs: echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" - if [ ${SIZE} -gt 57671680 ]; then - echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!" + if [ ${SIZE} -gt 58720256 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!" exit 1 fi diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 0f81229cd..65e63dfa8 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -171,6 +171,7 @@ nfpms: - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ + license: BSD-3-Clause id: netbird_deb bindir: /usr/bin builds: @@ -184,6 +185,7 @@ nfpms: - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ + license: BSD-3-Clause id: netbird_rpm bindir: /usr/bin builds: diff --git a/client/Dockerfile b/client/Dockerfile index 13e44096f..64d5ba04f 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -17,8 +17,7 @@ ENV \ NETBIRD_BIN="/usr/local/bin/netbird" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ - NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="5" + NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 5fa8de0a5..69d00aaf2 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -23,8 +23,7 @@ ENV \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_DISABLE_DNS="true" \ - NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/cmd/debug.go b/client/cmd/debug.go index e480df4d7..0e2717756 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error { if stateWasDown { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { - 
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird up") + time.Sleep(time.Second * 10) } - cmd.Println("netbird up") - time.Sleep(time.Second * 10) } initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE @@ -199,9 +200,10 @@ func runForDuration(cmd *cobra.Command, args []string) error { } if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { - return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird down") } - cmd.Println("netbird down") time.Sleep(1 * time.Second) @@ -209,13 +211,14 @@ func runForDuration(cmd *cobra.Command, args []string) error { if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { - return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message()) } if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { - return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird up") } - cmd.Println("netbird up") time.Sleep(3 * time.Second) @@ -263,16 +266,18 @@ func runForDuration(cmd *cobra.Command, args []string) error { if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { - return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird down") } - cmd.Println("netbird down") } if !initialLevelTrace { if _, err := 
client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil { - return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message()) + cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("Log level restored to", initialLogLevel.GetLevel()) } - cmd.Println("Log level restored to", initialLogLevel.GetLevel()) } cmd.Printf("Local file:\n%s\n", resp.GetPath()) diff --git a/client/cmd/expose.go b/client/cmd/expose.go index 1334617d8..f4727703e 100644 --- a/client/cmd/expose.go +++ b/client/cmd/expose.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/netbirdio/netbird/client/internal/expose" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/util" ) @@ -211,19 +212,24 @@ func exposeFn(cmd *cobra.Command, args []string) error { } func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) { - switch strings.ToLower(exposeProtocol) { - case "http": + p, err := expose.ParseProtocolType(exposeProtocol) + if err != nil { + return 0, fmt.Errorf("invalid protocol: %w", err) + } + + switch p { + case expose.ProtocolHTTP: return proto.ExposeProtocol_EXPOSE_HTTP, nil - case "https": + case expose.ProtocolHTTPS: return proto.ExposeProtocol_EXPOSE_HTTPS, nil - case "tcp": + case expose.ProtocolTCP: return proto.ExposeProtocol_EXPOSE_TCP, nil - case "udp": + case expose.ProtocolUDP: return proto.ExposeProtocol_EXPOSE_UDP, nil - case "tls": + case expose.ProtocolTLS: return proto.ExposeProtocol_EXPOSE_TLS, nil default: - return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol) + return 0, fmt.Errorf("unhandled protocol type: %d", p) } } diff --git a/client/cmd/service.go b/client/cmd/service.go index e55465875..5ff16eaeb 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -41,7 +41,7 @@ func init() { defaultServiceName = 
"Netbird" } - serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd) + serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0545ce6b7..5fe318ddf 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error { // Common setup for service control commands func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) { - SetFlagsFromEnvVars(rootCmd) + // rootCmd env vars are already applied by PersistentPreRunE. 
SetFlagsFromEnvVars(serviceCmd) cmd.SetOut(cmd.OutOrStdout()) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index f6828d96a..28770ea16 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -119,6 +119,10 @@ var installCmd = &cobra.Command{ return err } + if err := loadAndApplyServiceParams(cmd); err != nil { + cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err) + } + svcConfig, err := createServiceConfigForInstall() if err != nil { return err @@ -136,6 +140,10 @@ var installCmd = &cobra.Command{ return fmt.Errorf("install service: %w", err) } + if err := saveServiceParams(currentServiceParams()); err != nil { + cmd.PrintErrf("Warning: failed to save service params: %v\n", err) + } + cmd.Println("NetBird service has been installed") return nil }, @@ -187,6 +195,10 @@ This command will temporarily stop the service, update its configuration, and re return err } + if err := loadAndApplyServiceParams(cmd); err != nil { + cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err) + } + wasRunning, err := isServiceRunning() if err != nil && !errors.Is(err, ErrGetServiceStatus) { return fmt.Errorf("check service status: %w", err) @@ -222,6 +234,10 @@ This command will temporarily stop the service, update its configuration, and re return fmt.Errorf("install service with new config: %w", err) } + if err := saveServiceParams(currentServiceParams()); err != nil { + cmd.PrintErrf("Warning: failed to save service params: %v\n", err) + } + if wasRunning { cmd.Println("Starting NetBird service...") if err := s.Start(); err != nil { diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go new file mode 100644 index 000000000..81bd2dbb5 --- /dev/null +++ b/client/cmd/service_params.go @@ -0,0 +1,201 @@ +//go:build !ios && !android + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "os" + "path/filepath" + + "github.com/spf13/cobra" + + 
"github.com/netbirdio/netbird/client/configs" + "github.com/netbirdio/netbird/util" +) + +const serviceParamsFile = "service.json" + +// serviceParams holds install-time service parameters that persist across +// uninstall/reinstall cycles. Saved to /service.json. +type serviceParams struct { + LogLevel string `json:"log_level"` + DaemonAddr string `json:"daemon_addr"` + ManagementURL string `json:"management_url,omitempty"` + ConfigPath string `json:"config_path,omitempty"` + LogFiles []string `json:"log_files,omitempty"` + DisableProfiles bool `json:"disable_profiles,omitempty"` + DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` + ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` +} + +// serviceParamsPath returns the path to the service params file. +func serviceParamsPath() string { + return filepath.Join(configs.StateDir, serviceParamsFile) +} + +// loadServiceParams reads saved service parameters from disk. +// Returns nil with no error if the file does not exist. +func loadServiceParams() (*serviceParams, error) { + path := serviceParamsPath() + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil //nolint:nilnil + } + return nil, fmt.Errorf("read service params %s: %w", path, err) + } + + var params serviceParams + if err := json.Unmarshal(data, ¶ms); err != nil { + return nil, fmt.Errorf("parse service params %s: %w", path, err) + } + + return ¶ms, nil +} + +// saveServiceParams writes current service parameters to disk atomically +// with restricted permissions. +func saveServiceParams(params *serviceParams) error { + path := serviceParamsPath() + if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil { + return fmt.Errorf("save service params: %w", err) + } + return nil +} + +// currentServiceParams captures the current state of all package-level +// variables into a serviceParams struct. 
+func currentServiceParams() *serviceParams { + params := &serviceParams{ + LogLevel: logLevel, + DaemonAddr: daemonAddr, + ManagementURL: managementURL, + ConfigPath: configPath, + LogFiles: logFiles, + DisableProfiles: profilesDisabled, + DisableUpdateSettings: updateSettingsDisabled, + } + + if len(serviceEnvVars) > 0 { + parsed, err := parseServiceEnvVars(serviceEnvVars) + if err == nil && len(parsed) > 0 { + params.ServiceEnvVars = parsed + } + } + + return params +} + +// loadAndApplyServiceParams loads saved params from disk and applies them +// to any flags that were not explicitly set. +func loadAndApplyServiceParams(cmd *cobra.Command) error { + params, err := loadServiceParams() + if err != nil { + return err + } + applyServiceParams(cmd, params) + return nil +} + +// applyServiceParams merges saved parameters into package-level variables +// for any flag that was not explicitly set by the user (via CLI or env var). +// Flags that were Changed() are left untouched. +func applyServiceParams(cmd *cobra.Command, params *serviceParams) { + if params == nil { + return + } + + // For fields with non-empty defaults (log-level, daemon-addr), keep the + // != "" guard so that an older service.json missing the field doesn't + // clobber the default with an empty string. + if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" { + logLevel = params.LogLevel + } + + if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" { + daemonAddr = params.DaemonAddr + } + + // For optional fields where empty means "use default", always apply so + // that an explicit clear (--management-url "") persists across reinstalls. 
+ if !rootCmd.PersistentFlags().Changed("management-url") { + managementURL = params.ManagementURL + } + + if !rootCmd.PersistentFlags().Changed("config") { + configPath = params.ConfigPath + } + + if !rootCmd.PersistentFlags().Changed("log-file") { + logFiles = params.LogFiles + } + + if !serviceCmd.PersistentFlags().Changed("disable-profiles") { + profilesDisabled = params.DisableProfiles + } + + if !serviceCmd.PersistentFlags().Changed("disable-update-settings") { + updateSettingsDisabled = params.DisableUpdateSettings + } + + applyServiceEnvParams(cmd, params) +} + +// applyServiceEnvParams merges saved service environment variables. +// If --service-env was explicitly set, explicit values win on key conflict +// but saved keys not in the explicit set are carried over. +// If --service-env was not set, saved env vars are used entirely. +func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) { + if len(params.ServiceEnvVars) == 0 { + return + } + + if !cmd.Flags().Changed("service-env") { + // No explicit env vars: rebuild serviceEnvVars from saved params. + serviceEnvVars = envMapToSlice(params.ServiceEnvVars) + return + } + + // Explicit env vars were provided: merge saved values underneath. 
+ explicit, err := parseServiceEnvVars(serviceEnvVars) + if err != nil { + cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err) + return + } + + merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit)) + maps.Copy(merged, params.ServiceEnvVars) + maps.Copy(merged, explicit) // explicit wins on conflict + serviceEnvVars = envMapToSlice(merged) +} + +var resetParamsCmd = &cobra.Command{ + Use: "reset-params", + Short: "Remove saved service install parameters", + Long: "Removes the saved service.json file so the next install uses default parameters.", + RunE: func(cmd *cobra.Command, args []string) error { + path := serviceParamsPath() + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + cmd.Println("No saved service parameters found") + return nil + } + return fmt.Errorf("remove service params: %w", err) + } + cmd.Printf("Removed saved service parameters (%s)\n", path) + return nil + }, +} + +// envMapToSlice converts a map of env vars to a KEY=VALUE slice. 
+func envMapToSlice(m map[string]string) []string { + s := make([]string, 0, len(m)) + for k, v := range m { + s = append(s, k+"="+v) + } + return s +} diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go new file mode 100644 index 000000000..684593a00 --- /dev/null +++ b/client/cmd/service_params_test.go @@ -0,0 +1,523 @@ +//go:build !ios && !android + +package cmd + +import ( + "encoding/json" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/configs" +) + +func TestServiceParamsPath(t *testing.T) { + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + + configs.StateDir = "/var/lib/netbird" + assert.Equal(t, "/var/lib/netbird/service.json", serviceParamsPath()) + + configs.StateDir = "/custom/state" + assert.Equal(t, "/custom/state/service.json", serviceParamsPath()) +} + +func TestSaveAndLoadServiceParams(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + params := &serviceParams{ + LogLevel: "debug", + DaemonAddr: "unix:///var/run/netbird.sock", + ManagementURL: "https://my.server.com", + ConfigPath: "/etc/netbird/config.json", + LogFiles: []string{"/var/log/netbird/client.log", "console"}, + DisableProfiles: true, + DisableUpdateSettings: false, + ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"}, + } + + err := saveServiceParams(params) + require.NoError(t, err) + + // Verify the file exists and is valid JSON. 
+ data, err := os.ReadFile(filepath.Join(tmpDir, "service.json")) + require.NoError(t, err) + assert.True(t, json.Valid(data)) + + loaded, err := loadServiceParams() + require.NoError(t, err) + require.NotNil(t, loaded) + + assert.Equal(t, params.LogLevel, loaded.LogLevel) + assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr) + assert.Equal(t, params.ManagementURL, loaded.ManagementURL) + assert.Equal(t, params.ConfigPath, loaded.ConfigPath) + assert.Equal(t, params.LogFiles, loaded.LogFiles) + assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles) + assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings) + assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars) +} + +func TestLoadServiceParams_FileNotExists(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + params, err := loadServiceParams() + assert.NoError(t, err) + assert.Nil(t, params) +} + +func TestLoadServiceParams_InvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + + original := configs.StateDir + t.Cleanup(func() { configs.StateDir = original }) + configs.StateDir = tmpDir + + err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600) + require.NoError(t, err) + + params, err := loadServiceParams() + assert.Error(t, err) + assert.Nil(t, params) +} + +func TestCurrentServiceParams(t *testing.T) { + origLogLevel := logLevel + origDaemonAddr := daemonAddr + origManagementURL := managementURL + origConfigPath := configPath + origLogFiles := logFiles + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { + logLevel = origLogLevel + daemonAddr = origDaemonAddr + managementURL = origManagementURL + configPath = origConfigPath + logFiles = origLogFiles + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = 
origUpdateSettingsDisabled + serviceEnvVars = origServiceEnvVars + }) + + logLevel = "trace" + daemonAddr = "tcp://127.0.0.1:9999" + managementURL = "https://mgmt.example.com" + configPath = "/tmp/test-config.json" + logFiles = []string{"/tmp/test.log"} + profilesDisabled = true + updateSettingsDisabled = true + serviceEnvVars = []string{"FOO=bar", "BAZ=qux"} + + params := currentServiceParams() + + assert.Equal(t, "trace", params.LogLevel) + assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr) + assert.Equal(t, "https://mgmt.example.com", params.ManagementURL) + assert.Equal(t, "/tmp/test-config.json", params.ConfigPath) + assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles) + assert.True(t, params.DisableProfiles) + assert.True(t, params.DisableUpdateSettings) + assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars) +} + +func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) { + origLogLevel := logLevel + origDaemonAddr := daemonAddr + origManagementURL := managementURL + origConfigPath := configPath + origLogFiles := logFiles + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { + logLevel = origLogLevel + daemonAddr = origDaemonAddr + managementURL = origManagementURL + configPath = origConfigPath + logFiles = origLogFiles + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + serviceEnvVars = origServiceEnvVars + }) + + // Reset all flags to defaults. + logLevel = "info" + daemonAddr = "unix:///var/run/netbird.sock" + managementURL = "" + configPath = "/etc/netbird/config.json" + logFiles = []string{"/var/log/netbird/client.log"} + profilesDisabled = false + updateSettingsDisabled = false + serviceEnvVars = nil + + // Reset Changed state on all relevant flags. 
+ rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + // Simulate user explicitly setting --log-level via CLI. + logLevel = "warn" + require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn")) + + saved := &serviceParams{ + LogLevel: "debug", + DaemonAddr: "tcp://127.0.0.1:5555", + ManagementURL: "https://saved.example.com", + ConfigPath: "/saved/config.json", + LogFiles: []string{"/saved/client.log"}, + DisableProfiles: true, + DisableUpdateSettings: true, + ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"}, + } + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + // log-level was Changed, so it should keep "warn", not use saved "debug". + assert.Equal(t, "warn", logLevel) + + // All other fields were not Changed, so they should use saved values. + assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr) + assert.Equal(t, "https://saved.example.com", managementURL) + assert.Equal(t, "/saved/config.json", configPath) + assert.Equal(t, []string{"/saved/client.log"}, logFiles) + assert.True(t, profilesDisabled) + assert.True(t, updateSettingsDisabled) + assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars) +} + +func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) { + origProfilesDisabled := profilesDisabled + origUpdateSettingsDisabled := updateSettingsDisabled + t.Cleanup(func() { + profilesDisabled = origProfilesDisabled + updateSettingsDisabled = origUpdateSettingsDisabled + }) + + // Simulate current state where booleans are true (e.g. set by previous install). + profilesDisabled = true + updateSettingsDisabled = true + + // Reset Changed state so flags appear unset. + serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + // Saved params have both as false. 
+ saved := &serviceParams{ + DisableProfiles: false, + DisableUpdateSettings: false, + } + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + assert.False(t, profilesDisabled, "saved false should override current true") + assert.False(t, updateSettingsDisabled, "saved false should override current true") +} + +func TestApplyServiceParams_ClearManagementURL(t *testing.T) { + origManagementURL := managementURL + t.Cleanup(func() { managementURL = origManagementURL }) + + managementURL = "https://leftover.example.com" + + // Simulate saved params where management URL was explicitly cleared. + saved := &serviceParams{ + LogLevel: "info", + DaemonAddr: "unix:///var/run/netbird.sock", + // ManagementURL intentionally empty: was cleared with --management-url "". + } + + rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + }) + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + applyServiceParams(cmd, saved) + + assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value") +} + +func TestApplyServiceParams_NilParams(t *testing.T) { + origLogLevel := logLevel + t.Cleanup(func() { logLevel = origLogLevel }) + + logLevel = "info" + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + + // Should be a no-op. + applyServiceParams(cmd, nil) + assert.Equal(t, "info", logLevel) +} + +func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Set up a command with --service-env marked as Changed. 
+ cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit")) + + serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"} + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{ + "SAVED": "val", + "OVERLAP": "saved", + }, + } + + applyServiceEnvParams(cmd, saved) + + // Parse result for easier assertion. + result, err := parseServiceEnvVars(serviceEnvVars) + require.NoError(t, err) + + assert.Equal(t, "yes", result["EXPLICIT"]) + assert.Equal(t, "val", result["SAVED"]) + // Explicit wins on conflict. + assert.Equal(t, "explicit", result["OVERLAP"]) +} + +func TestApplyServiceEnvParams_NotChanged(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + serviceEnvVars = nil + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{"FROM_SAVED": "val"}, + } + + applyServiceEnvParams(cmd, saved) + + result, err := parseServiceEnvVars(serviceEnvVars) + require.NoError(t, err) + assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result) +} + +// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are +// referenced in both currentServiceParams() and applyServiceParams(). If a new field is +// added to serviceParams but not wired into these functions, this test fails. +func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "service_params.go", nil, 0) + require.NoError(t, err) + + // Collect all JSON field names from the serviceParams struct. + structFields := extractStructJSONFields(t, file, "serviceParams") + require.NotEmpty(t, structFields, "failed to find serviceParams struct fields") + + // Collect field names referenced in currentServiceParams and applyServiceParams. 
+ currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields) + applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields) + // applyServiceEnvParams handles ServiceEnvVars indirectly. + applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields) + for k, v := range applyEnvFields { + applyFields[k] = v + } + + for _, field := range structFields { + assert.Contains(t, currentFields, field, + "serviceParams field %q is not captured in currentServiceParams()", field) + assert.Contains(t, applyFields, field, + "serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field) + } +} + +// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references +// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because +// it flows through newSVCConfig() EnvVars, not CLI args. +func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "service_params.go", nil, 0) + require.NoError(t, err) + + structFields := extractStructJSONFields(t, file, "serviceParams") + require.NotEmpty(t, structFields) + + installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0) + require.NoError(t, err) + + // Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig). + fieldsNotInArgs := map[string]bool{ + "ServiceEnvVars": true, + } + + buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments") + + // Forward: every struct field must appear in buildServiceArguments. 
+ for _, field := range structFields { + if fieldsNotInArgs[field] { + continue + } + globalVar := fieldToGlobalVar(field) + assert.Contains(t, buildFields, globalVar, + "serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar) + } + + // Reverse: every service-related global used in buildServiceArguments must + // have a corresponding serviceParams field. This catches a developer adding + // a new flag to buildServiceArguments without adding it to the struct. + globalToField := make(map[string]string, len(structFields)) + for _, field := range structFields { + globalToField[fieldToGlobalVar(field)] = field + } + // Identifiers in buildServiceArguments that are not service params + // (builtins, boilerplate, loop variables). + nonParamGlobals := map[string]bool{ + "args": true, "append": true, "string": true, "_": true, + "logFile": true, // range variable over logFiles + } + for ref := range buildFields { + if nonParamGlobals[ref] { + continue + } + _, inStruct := globalToField[ref] + assert.True(t, inStruct, + "buildServiceArguments() references global %q which has no corresponding serviceParams field", ref) + } +} + +// extractStructJSONFields returns field names from a named struct type. +func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string { + t.Helper() + var fields []string + ast.Inspect(file, func(n ast.Node) bool { + ts, ok := n.(*ast.TypeSpec) + if !ok || ts.Name.Name != structName { + return true + } + st, ok := ts.Type.(*ast.StructType) + if !ok { + return false + } + for _, f := range st.Fields.List { + if len(f.Names) > 0 { + fields = append(fields, f.Names[0].Name) + } + } + return false + }) + return fields +} + +// extractFuncFieldRefs returns which of the given field names appear inside the +// named function, either as selector expressions (params.FieldName) or as +// composite literal keys (&serviceParams{FieldName: ...}). 
+func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool { + t.Helper() + fieldSet := make(map[string]bool, len(fields)) + for _, f := range fields { + fieldSet[f] = true + } + + found := make(map[string]bool) + fn := findFuncDecl(file, funcName) + require.NotNil(t, fn, "function %s not found", funcName) + + ast.Inspect(fn.Body, func(n ast.Node) bool { + switch v := n.(type) { + case *ast.SelectorExpr: + if fieldSet[v.Sel.Name] { + found[v.Sel.Name] = true + } + case *ast.KeyValueExpr: + if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] { + found[ident.Name] = true + } + } + return true + }) + return found +} + +// extractFuncGlobalRefs returns all identifier names referenced in the named function body. +func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool { + t.Helper() + fn := findFuncDecl(file, funcName) + require.NotNil(t, fn, "function %s not found", funcName) + + refs := make(map[string]bool) + ast.Inspect(fn.Body, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + refs[ident.Name] = true + } + return true + }) + return refs +} + +func findFuncDecl(file *ast.File, name string) *ast.FuncDecl { + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if ok && fn.Name.Name == name { + return fn + } + } + return nil +} + +// fieldToGlobalVar maps serviceParams field names to the package-level variable +// names used in buildServiceArguments and applyServiceParams. +func fieldToGlobalVar(field string) string { + m := map[string]string{ + "LogLevel": "logLevel", + "DaemonAddr": "daemonAddr", + "ManagementURL": "managementURL", + "ConfigPath": "configPath", + "LogFiles": "logFiles", + "DisableProfiles": "profilesDisabled", + "DisableUpdateSettings": "updateSettingsDisabled", + "ServiceEnvVars": "serviceEnvVars", + } + if v, ok := m[field]; ok { + return v + } + // Default: lowercase first letter. 
+ return strings.ToLower(field[:1]) + field[1:] +} + +func TestEnvMapToSlice(t *testing.T) { + m := map[string]string{"A": "1", "B": "2"} + s := envMapToSlice(m) + assert.Len(t, s, 2) + assert.Contains(t, s, "A=1") + assert.Contains(t, s, "B=2") +} + +func TestEnvMapToSlice_Empty(t *testing.T) { + s := envMapToSlice(map[string]string{}) + assert.Empty(t, s) +} diff --git a/client/cmd/status.go b/client/cmd/status.go index f09c35c2c..c35a06eb3 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -28,6 +28,7 @@ var ( ipsFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{} connectionTypeFilter string + checkFlag string ) var statusCmd = &cobra.Command{ @@ -49,6 +50,7 @@ func init() { statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") + statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") } func statusFunc(cmd *cobra.Command, args []string) error { @@ -56,6 +58,10 @@ func statusFunc(cmd *cobra.Command, args []string) error { cmd.SetOut(cmd.OutOrStdout()) + if checkFlag != "" { + return runHealthCheck(cmd) + } + err := parseFilters() if err != nil { return err @@ -68,15 +74,17 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx, false) + resp, err := getStatus(ctx, true, false) if err != nil { return err } 
status := resp.GetStatus() - if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) || - status == string(internal.StatusSessionExpired) { + needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) || + status == string(internal.StatusSessionExpired) + + if needsAuth && !jsonFlag && !yamlFlag { cmd.Printf("Daemon status: %s\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+ " netbird up \n\n"+ @@ -99,7 +107,17 @@ func statusFunc(cmd *cobra.Command, args []string) error { profName = activeProf.Name } - var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName) + var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{ + Anonymize: anonymizeFlag, + DaemonVersion: resp.GetDaemonVersion(), + DaemonStatus: nbstatus.ParseDaemonStatus(status), + StatusFilter: statusFilter, + PrefixNamesFilter: prefixNamesFilter, + PrefixNamesFilterMap: prefixNamesFilterMap, + IPsFilter: ipsFilterMap, + ConnectionTypeFilter: connectionTypeFilter, + ProfileName: profName, + }) var statusOutputString string switch { case detailFlag: @@ -121,7 +139,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { //nolint @@ -131,7 +149,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) + 
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } @@ -185,6 +203,83 @@ func enableDetailFlagWhenFilterFlag() { } } +func runHealthCheck(cmd *cobra.Command) error { + check := strings.ToLower(checkFlag) + switch check { + case "live", "ready", "startup": + default: + return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag) + } + + if err := util.InitLog(logLevel, util.LogConsole); err != nil { + return fmt.Errorf("init log: %w", err) + } + + ctx := internal.CtxInitState(cmd.Context()) + + isStartup := check == "startup" + resp, err := getStatus(ctx, isStartup, false) + if err != nil { + return err + } + + switch check { + case "live": + return nil + case "ready": + return checkReadiness(resp) + case "startup": + return checkStartup(resp) + default: + return nil + } +} + +func checkReadiness(resp *proto.StatusResponse) error { + daemonStatus := internal.StatusType(resp.GetStatus()) + switch daemonStatus { + case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected: + return nil + case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired: + return fmt.Errorf("readiness check: daemon status is %s", daemonStatus) + default: + return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus) + } +} + +func checkStartup(resp *proto.StatusResponse) error { + fullStatus := resp.GetFullStatus() + if fullStatus == nil { + return fmt.Errorf("startup check: no full status available") + } + + if !fullStatus.GetManagementState().GetConnected() { + return fmt.Errorf("startup check: management not connected") + } + + if !fullStatus.GetSignalState().GetConnected() { + return fmt.Errorf("startup check: signal not connected") + } + + var relayCount, relaysConnected int + for _, r := range 
fullStatus.GetRelays() { + uri := r.GetURI() + if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") { + continue + } + relayCount++ + if r.GetAvailable() { + relaysConnected++ + } + } + + if relayCount > 0 && relaysConnected == 0 { + return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount) + } + + return nil +} + func parseInterfaceIP(interfaceIP string) string { ip, _, err := net.ParseCIDR(interfaceIP) if err != nil { diff --git a/client/embed/embed.go b/client/embed/embed.go index 70013989a..88f7e541c 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" sshcommon "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -32,14 +33,14 @@ var ( ErrConfigNotInitialized = errors.New("config not initialized") ) -// PeerConnStatus is a peer's connection status. -type PeerConnStatus = peer.ConnStatus - const ( // PeerStatusConnected indicates the peer is in connected state. PeerStatusConnected = peer.StatusConnected ) +// PeerConnStatus is a peer's connection status. +type PeerConnStatus = peer.ConnStatus + // Client manages a netbird embedded client instance. type Client struct { deviceName string @@ -88,6 +89,8 @@ type Options struct { // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. // Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams. MTU *uint16 + // DNSLabels defines additional DNS labels configured in the peer. 
+ DNSLabels []string } // validateCredentials checks that exactly one credential type is provided @@ -153,9 +156,14 @@ func New(opts Options) (*Client, error) { } } + var err error + var parsedLabels domain.List + if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil { + return nil, fmt.Errorf("invalid dns labels: %w", err) + } + t := true var config *profilemanager.Config - var err error input := profilemanager.ConfigInput{ ConfigPath: opts.ConfigPath, ManagementURL: opts.ManagementURL, @@ -165,6 +173,7 @@ func New(opts Options) (*Client, error) { BlockInbound: &opts.BlockInbound, WireguardPort: opts.WireguardPort, MTU: opts.MTU, + DNSLabels: parsedLabels, } if opts.ConfigPath != "" { config, err = profilemanager.UpdateOrCreateConfig(input) @@ -366,6 +375,32 @@ func (c *Client) NewHTTPClient() *http.Client { } } +// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL. +// It returns an ExposeSession. Call Wait on the session to keep it alive. +func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + mgr := engine.GetExposeManager() + if mgr == nil { + return nil, fmt.Errorf("expose manager not available") + } + + resp, err := mgr.Expose(ctx, req) + if err != nil { + return nil, fmt.Errorf("expose: %w", err) + } + + return &ExposeSession{ + Domain: resp.Domain, + ServiceName: resp.ServiceName, + ServiceURL: resp.ServiceURL, + mgr: mgr, + }, nil +} + // Status returns the current status of the client. 
func (c *Client) Status() (peer.FullStatus, error) { c.mu.Lock() diff --git a/client/embed/expose.go b/client/embed/expose.go new file mode 100644 index 000000000..825bb90ee --- /dev/null +++ b/client/embed/expose.go @@ -0,0 +1,45 @@ +package embed + +import ( + "context" + "errors" + + "github.com/netbirdio/netbird/client/internal/expose" +) + +const ( + // ExposeProtocolHTTP exposes the service as HTTP. + ExposeProtocolHTTP = expose.ProtocolHTTP + // ExposeProtocolHTTPS exposes the service as HTTPS. + ExposeProtocolHTTPS = expose.ProtocolHTTPS + // ExposeProtocolTCP exposes the service as TCP. + ExposeProtocolTCP = expose.ProtocolTCP + // ExposeProtocolUDP exposes the service as UDP. + ExposeProtocolUDP = expose.ProtocolUDP + // ExposeProtocolTLS exposes the service as TLS. + ExposeProtocolTLS = expose.ProtocolTLS +) + +// ExposeRequest is a request to expose a local service via the NetBird reverse proxy. +type ExposeRequest = expose.Request + +// ExposeProtocolType represents the protocol used for exposing a service. +type ExposeProtocolType = expose.ProtocolType + +// ExposeSession represents an active expose session. Use Wait to block until the session ends. +type ExposeSession struct { + Domain string + ServiceName string + ServiceURL string + + mgr *expose.Manager +} + +// Wait blocks while keeping the expose session alive. +// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session. 
+func (s *ExposeSession) Wait(ctx context.Context) error { + if s == nil || s.mgr == nil { + return errors.New("expose session is not initialized") + } + return s.mgr.KeepAlive(ctx, s.Domain) +} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 716385705..04c338375 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -23,9 +23,10 @@ type Manager struct { wgIface iFaceMapper - ipv4Client *iptables.IPTables - aclMgr *aclManager - router *router + ipv4Client *iptables.IPTables + aclMgr *aclManager + router *router + rawSupported bool } // iFaceMapper defines subset methods of interface required for manager @@ -84,7 +85,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.initNoTrackChain(); err != nil { - return fmt.Errorf("init notrack chain: %w", err) + log.Warnf("raw table not available, notrack rules will be disabled: %v", err) } // persist early to ensure cleanup of chains @@ -318,6 +319,10 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() + if !m.rawSupported { + return fmt.Errorf("raw table not available") + } + wgPortStr := fmt.Sprintf("%d", wgPort) proxyPortStr := fmt.Sprintf("%d", proxyPort) @@ -375,12 +380,16 @@ func (m *Manager) initNoTrackChain() error { return fmt.Errorf("add prerouting jump rule: %w", err) } + m.rawSupported = true return nil } func (m *Manager) cleanupNoTrackChain() error { exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw) if err != nil { + if !m.rawSupported { + return nil + } return fmt.Errorf("check chain exists: %w", err) } if !exists { @@ -401,6 +410,7 @@ func (m *Manager) cleanupNoTrackChain() error { return fmt.Errorf("clear and delete chain: %w", err) } + m.rawSupported = false return nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 
acf482f86..f57b28abc 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -95,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.initNoTrackChains(workTable); err != nil { - return fmt.Errorf("init notrack chains: %w", err) + log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } stateManager.RegisterState(&ShutdownState{}) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 54966b50e..9a6bc0670 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff { // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). -func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) // for js, the outer websocket layer takes care of tls if tlsEnabled && runtime.GOOS != "js" { @@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - conn, err := grpc.DialContext( - connCtx, - addr, + opts := []grpc.DialOption{ transportOption, WithCustomDialer(tlsEnabled, component), grpc.WithBlock(), @@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone Time: 30 * time.Second, Timeout: 10 * time.Second, }), - ) + } + opts = append(opts, extraOpts...) + + conn, err := grpc.DialContext(connCtx, addr, opts...) 
if err != nil { return nil, fmt.Errorf("dial context: %w", err) } diff --git a/client/internal/connect.go b/client/internal/connect.go index ccd7b6c33..242b25b44 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -50,6 +51,7 @@ type ConnectClient struct { engine *Engine engineMutex sync.Mutex + clientMetrics *metrics.ClientMetrics updateManager *updater.Manager persistSyncResponse bool @@ -133,10 +135,34 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } }() + // Stop metrics push on exit + defer func() { + if c.clientMetrics != nil { + c.clientMetrics.StopPush() + } + }() + log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) nbnet.Init() + // Initialize metrics once at startup (always active for debug bundles) + if c.clientMetrics == nil { + agentInfo := metrics.AgentInfo{ + DeploymentType: metrics.DeploymentTypeUnknown, + Version: version.NetbirdVersion(), + OS: runtime.GOOS, + Arch: runtime.GOARCH, + } + c.clientMetrics = metrics.NewClientMetrics(agentInfo) + log.Debugf("initialized client metrics") + + // Start metrics push if enabled (uses daemon context, persists across engine restarts) + if metrics.IsMetricsPushEnabled() { + c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv()) + } + } + backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -223,6 +249,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan mgmNotifier := 
statusRecorderToMgmConnStateNotifier(c.statusRecorder) mgmClient.SetConnStateListener(mgmNotifier) + // Update metrics with actual deployment type after connection + deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL()) + agentInfo := metrics.AgentInfo{ + DeploymentType: deploymentType, + Version: version.NetbirdVersion(), + OS: runtime.GOOS, + Arch: runtime.GOARCH, + } + c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String()) + log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host) defer func() { if err = mgmClient.Close(); err != nil { @@ -231,8 +267,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan }() // connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config + loginStarted := time.Now() loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config) if err != nil { + c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false) log.Debug(err) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { state.Set(StatusNeedsLogin) @@ -241,6 +279,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } return wrapErr(err) } + c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true) c.statusRecorder.MarkManagementConnected() localPeerState := peer.LocalPeerState{ @@ -317,6 +356,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan Checks: checks, StateManager: stateManager, UpdateManager: c.updateManager, + ClientMetrics: c.clientMetrics, }, mobileDependency) engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine = engine diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index f0f399bef..c9ebf25e5 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -31,7 +31,6 @@ import ( nbstatus 
"github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/version" ) const readmeContent = `Netbird debug bundle @@ -53,6 +52,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states for the active profile. +metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -219,6 +219,11 @@ const ( darwinStdoutLogPath = "/var/log/netbird.err.log" ) +// MetricsExporter is an interface for exporting metrics +type MetricsExporter interface { + Export(w io.Writer) error +} + type BundleGenerator struct { anonymizer *anonymize.Anonymizer @@ -229,6 +234,7 @@ type BundleGenerator struct { logPath string cpuProfile []byte refreshStatus func() // Optional callback to refresh status before bundle generation + clientMetrics MetricsExporter anonymize bool includeSystemInfo bool @@ -250,6 +256,7 @@ type GeneratorDependencies struct { LogPath string CPUProfile []byte RefreshStatus func() // Optional callback to refresh status before bundle generation + ClientMetrics MetricsExporter } func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator { @@ -268,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen logPath: deps.LogPath, cpuProfile: deps.CPUProfile, refreshStatus: deps.RefreshStatus, + clientMetrics: deps.ClientMetrics, anonymize: cfg.Anonymize, includeSystemInfo: cfg.IncludeSystemInfo, @@ 
-351,6 +359,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add corrupted state files to debug bundle: %v", err) } + if err := g.addMetrics(); err != nil { + log.Errorf("failed to add metrics to debug bundle: %v", err) + } + if err := g.addWgShow(); err != nil { log.Errorf("failed to add wg show output: %v", err) } @@ -418,7 +430,10 @@ func (g *BundleGenerator) addStatus() error { fullStatus := g.statusRecorder.GetFullStatus() protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus) protoFullStatus.Events = g.statusRecorder.GetEventHistory() - overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName) + overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{ + Anonymize: g.anonymize, + ProfileName: profName, + }) statusOutput := overview.FullDetailSummary() statusReader := strings.NewReader(statusOutput) @@ -744,6 +759,30 @@ func (g *BundleGenerator) addCorruptedStateFiles() error { return nil } +func (g *BundleGenerator) addMetrics() error { + if g.clientMetrics == nil { + log.Debugf("skipping metrics in debug bundle: no metrics collector") + return nil + } + + var buf bytes.Buffer + if err := g.clientMetrics.Export(&buf); err != nil { + return fmt.Errorf("export metrics: %w", err) + } + + if buf.Len() == 0 { + log.Debugf("skipping metrics.txt in debug bundle: no metrics data") + return nil + } + + if err := g.addFileToZip(&buf, "metrics.txt"); err != nil { + return fmt.Errorf("add metrics file to zip: %w", err) + } + + log.Debugf("added metrics to debug bundle") + return nil +} + func (g *BundleGenerator) addLogfile() error { if g.logPath == "" { log.Debugf("skipping empty log file in debug bundle") diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index fe160e20a..1df57d1db 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -85,6 +85,11 @@ 
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } +// SetRouteChecker mock implementation of SetRouteChecker from Server interface +func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { + // Mock implementation - no-op +} + // BeginBatch mock implementation of BeginBatch from Server interface func (m *MockServer) BeginBatch() { // Mock implementation - no-op diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6ca4f7957..3c47f4ee6 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -57,6 +57,7 @@ type Server interface { ProbeAvailability() UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error + SetRouteChecker(func(netip.Addr) bool) } type nsGroupsByDomain struct { @@ -104,6 +105,7 @@ type DefaultServer struct { statusRecorder *peer.Status stateManager *statemanager.Manager + routeMatch func(netip.Addr) bool probeMu sync.Mutex probeCancel context.CancelFunc @@ -229,6 +231,14 @@ func newDefaultServer( return defaultServer } +// SetRouteChecker sets the function used by upstream resolvers to determine +// whether an IP is routed through the tunnel. +func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) { + s.mux.Lock() + defer s.mux.Unlock() + s.routeMatch = f +} + // RegisterHandler registers a handler for the given domains with the given priority. // Any previously registered handler for the same domain and priority will be replaced. 
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -743,6 +753,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { log.Errorf("failed to create upstream resolver for original nameservers: %v", err) return } + handler.routeMatch = s.routeMatch for _, ns := range originalNameservers { if ns == config.ServerIP { @@ -852,6 +863,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai if err != nil { return nil, fmt.Errorf("create upstream resolver: %v", err) } + handler.routeMatch = s.routeMatch for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { @@ -1036,6 +1048,7 @@ func (s *DefaultServer) addHostRootZone() { log.Errorf("unable to create a new upstream resolver, error: %v", err) return } + handler.routeMatch = s.routeMatch handler.upstreamServers = maps.Keys(hostDNSServers) handler.deactivate = func(error) {} diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 806559444..f7ddfd40f 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "runtime" + "strconv" "sync" "time" @@ -69,7 +70,7 @@ func (s *serviceViaListener) Listen() error { return fmt.Errorf("eval listen address: %w", err) } s.listenIP = s.listenIP.Unmap() - s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) + s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort))) log.Debugf("starting dns on %s", s.server.Addr) go func() { s.setListenerStatus(true) @@ -186,7 +187,7 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) { } func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool { - addrString := fmt.Sprintf("%s:%d", ip, port) + addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port)) udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) 
probeListener, err := net.ListenUDP("udp", udpAddr) if err != nil { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 18128a942..5b8135132 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -70,6 +70,7 @@ type upstreamResolverBase struct { deactivate func(error) reactivate func() statusRecorder *peer.Status + routeMatch func(netip.Addr) bool } type upstreamFailure struct { diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 4d053a5a1..02c11173b 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -65,11 +65,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { - log.Debugf("using private client to query upstream: %s", upstream) + needsPrivate := u.lNet.Contains(upstreamIP) || + (u.routeMatch != nil && u.routeMatch(upstreamIP)) + if needsPrivate { + log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) if err != nil { - return nil, 0, fmt.Errorf("error while creating private client: %s", err) + return nil, 0, fmt.Errorf("create private client: %s", err) } } diff --git a/client/internal/engine.go b/client/internal/engine.go index 3d72de5a7..af579f9b4 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/expose" "github.com/netbirdio/netbird/client/internal/ingressgw" + "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/netflow" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/networkmonitor" @@ -150,6 +151,7 @@ type 
EngineServices struct { Checks []*mgmProto.Checks StateManager *statemanager.Manager UpdateManager *updater.Manager + ClientMetrics *metrics.ClientMetrics } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -231,6 +233,9 @@ type Engine struct { probeStunTurn *relay.StunTurnProbe + // clientMetrics collects and pushes metrics + clientMetrics *metrics.ClientMetrics + jobExecutor *jobexec.Executor jobExecutorWG sync.WaitGroup @@ -274,7 +279,9 @@ func NewEngine( portForwardManager: portforward.NewManager(), checks: services.Checks, probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), - jobExecutor: jobexec.NewExecutor(), updateManager: services.UpdateManager, + jobExecutor: jobexec.NewExecutor(), + clientMetrics: services.ClientMetrics, + updateManager: services.UpdateManager, } log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) @@ -495,6 +502,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) + e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool { + for _, routes := range e.routeManager.GetClientRoutes() { + for _, r := range routes { + if r.Network.Contains(ip) { + return true + } + } + } + return false + }) + if err = e.wgInterfaceCreate(); err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) e.close() @@ -822,7 +840,9 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { started := time.Now() defer func() { - log.Infof("sync finished in %s", time.Since(started)) + duration := time.Since(started) + log.Infof("sync finished in %s", duration) + e.clientMetrics.RecordSyncDuration(e.ctx, duration) }() e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -998,10 +1018,11 @@ func (e *Engine) updateConfig(conf 
*mgmProto.PeerConfig) error { return errors.New("wireguard interface is not initialized") } - // Cannot update the IP address without restarting the engine because - // the firewall, route manager, and other components cache the old address if e.wgInterface.Address().String() != conf.Address { - log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address) + log.Infof("peer IP address changed from %s to %s, restarting client", e.wgInterface.Address().String(), conf.Address) + _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) + e.clientCancel() + return ErrResetConnection } if conf.GetSshConfig() != nil { @@ -1069,6 +1090,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR StatusRecorder: e.statusRecorder, SyncResponse: syncResponse, LogPath: e.config.LogPath, + ClientMetrics: e.clientMetrics, RefreshStatus: func() { e.RunHealthProbes(true) }, @@ -1529,6 +1551,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV RelayManager: e.relayManager, SrWatcher: e.srWatcher, PortForwardManager: e.portForwardManager, + MetricsRecorder: e.clientMetrics, } peerConn, err := peer.NewConn(config, serviceDependencies) if err != nil { @@ -1831,6 +1854,11 @@ func (e *Engine) GetExposeManager() *expose.Manager { return e.exposeManager } +// GetClientMetrics returns the client metrics +func (e *Engine) GetClientMetrics() *metrics.ClientMetrics { + return e.clientMetrics +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f9e7f8fa0..77fe9049b 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, EngineServices{ + }, EngineServices{ SignalClient: 
&signal.MockClient{}, MgmClient: &mgmt.MockClient{}, RelayManager: relayMgr, @@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, - }, EngineServices{ + }, EngineServices{ SignalClient: &signal.MockClient{}, MgmClient: &mgmt.MockClient{}, RelayManager: relayMgr, @@ -1566,7 +1566,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - e, err := NewEngine(ctx, cancel, conf, EngineServices{ + e, err := NewEngine(ctx, cancel, conf, EngineServices{ SignalClient: signalClient, MgmClient: mgmtClient, RelayManager: relayMgr, diff --git a/client/internal/expose/manager.go b/client/internal/expose/manager.go index c59a1a7bd..076f92043 100644 --- a/client/internal/expose/manager.go +++ b/client/internal/expose/manager.go @@ -4,11 +4,14 @@ import ( "context", "time" - mgm "github.com/netbirdio/netbird/shared/management/client" log "github.com/sirupsen/logrus" + + mgm "github.com/netbirdio/netbird/shared/management/client" ) -const renewTimeout = 10 * time.Second +const ( + renewTimeout = 10 * time.Second +) // Response holds the response from exposing a service. type Response struct { @@ -18,11 +21,13 @@ type Response struct { PortAutoAssigned bool } +// Request holds the parameters for exposing a local service via the management server. +// It is part of the embed API surface and exposed via a type alias. type Request struct { NamePrefix string Domain string Port uint16 - Protocol int + Protocol ProtocolType Pin string Password string UserGroups []string @@ -59,6 +64,8 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) { return fromClientExposeResponse(resp), nil } +// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs.
+// It is part of the embed API surface and exposed via a type alias. func (m *Manager) KeepAlive(ctx context.Context, domain string) error { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() diff --git a/client/internal/expose/manager_test.go b/client/internal/expose/manager_test.go index 87d43cdb0..7d76c9838 100644 --- a/client/internal/expose/manager_test.go +++ b/client/internal/expose/manager_test.go @@ -86,7 +86,7 @@ func TestNewRequest(t *testing.T) { exposeReq := NewRequest(req) assert.Equal(t, uint16(8080), exposeReq.Port, "port should match") - assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match") + assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match") assert.Equal(t, "123456", exposeReq.Pin, "pin should match") assert.Equal(t, "secret", exposeReq.Password, "password should match") assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match") diff --git a/client/internal/expose/protocol.go b/client/internal/expose/protocol.go new file mode 100644 index 000000000..d5026d51e --- /dev/null +++ b/client/internal/expose/protocol.go @@ -0,0 +1,40 @@ +package expose + +import ( + "fmt" + "strings" +) + +// ProtocolType represents the protocol used for exposing a service. +type ProtocolType int + +const ( + // ProtocolHTTP exposes the service as HTTP. + ProtocolHTTP ProtocolType = 0 + // ProtocolHTTPS exposes the service as HTTPS. + ProtocolHTTPS ProtocolType = 1 + // ProtocolTCP exposes the service as TCP. + ProtocolTCP ProtocolType = 2 + // ProtocolUDP exposes the service as UDP. + ProtocolUDP ProtocolType = 3 + // ProtocolTLS exposes the service as TLS. + ProtocolTLS ProtocolType = 4 +) + +// ParseProtocolType parses a protocol string into a ProtocolType. 
+func ParseProtocolType(s string) (ProtocolType, error) { + switch strings.ToLower(s) { + case "http": + return ProtocolHTTP, nil + case "https": + return ProtocolHTTPS, nil + case "tcp": + return ProtocolTCP, nil + case "udp": + return ProtocolUDP, nil + case "tls": + return ProtocolTLS, nil + default: + return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s) + } +} diff --git a/client/internal/expose/request.go b/client/internal/expose/request.go index bff4f2ce7..ec75bb276 100644 --- a/client/internal/expose/request.go +++ b/client/internal/expose/request.go @@ -9,7 +9,7 @@ import ( func NewRequest(req *daemonProto.ExposeServiceRequest) *Request { return &Request{ Port: uint16(req.Port), - Protocol: int(req.Protocol), + Protocol: ProtocolType(req.Protocol), Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, @@ -24,7 +24,7 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest { NamePrefix: req.NamePrefix, Domain: req.Domain, Port: req.Port, - Protocol: req.Protocol, + Protocol: int(req.Protocol), Pin: req.Pin, Password: req.Password, UserGroups: req.UserGroups, diff --git a/client/internal/metrics/connection_type.go b/client/internal/metrics/connection_type.go new file mode 100644 index 000000000..a3406a6b8 --- /dev/null +++ b/client/internal/metrics/connection_type.go @@ -0,0 +1,17 @@ +package metrics + +// ConnectionType represents the type of peer connection +type ConnectionType string + +const ( + // ConnectionTypeICE represents a direct peer-to-peer connection using ICE + ConnectionTypeICE ConnectionType = "ice" + + // ConnectionTypeRelay represents a relayed connection + ConnectionTypeRelay ConnectionType = "relay" +) + +// String returns the string representation of the connection type +func (c ConnectionType) String() string { + return string(c) +} diff --git a/client/internal/metrics/deployment_type.go b/client/internal/metrics/deployment_type.go new file mode 100644 index 000000000..141173cb8 --- 
/dev/null +++ b/client/internal/metrics/deployment_type.go @@ -0,0 +1,51 @@ +package metrics + +import ( + "net/url" + "strings" +) + +// DeploymentType represents the type of NetBird deployment +type DeploymentType int + +const ( + // DeploymentTypeUnknown represents an unknown or uninitialized deployment type + DeploymentTypeUnknown DeploymentType = iota + + // DeploymentTypeCloud represents a cloud-hosted NetBird deployment + DeploymentTypeCloud + + // DeploymentTypeSelfHosted represents a self-hosted NetBird deployment + DeploymentTypeSelfHosted +) + +// String returns the string representation of the deployment type +func (d DeploymentType) String() string { + switch d { + case DeploymentTypeCloud: + return "cloud" + case DeploymentTypeSelfHosted: + return "selfhosted" + default: + return "unknown" + } +} + +// DetermineDeploymentType determines if the deployment is cloud or self-hosted +// based on the management URL string +func DetermineDeploymentType(managementURL string) DeploymentType { + if managementURL == "" { + return DeploymentTypeUnknown + } + + u, err := url.Parse(managementURL) + if err != nil { + return DeploymentTypeSelfHosted + } + + if strings.ToLower(u.Hostname()) == "api.netbird.io" { + return DeploymentTypeCloud + } + + return DeploymentTypeSelfHosted +} diff --git a/client/internal/metrics/env.go b/client/internal/metrics/env.go new file mode 100644 index 000000000..1f06ce484 --- /dev/null +++ b/client/internal/metrics/env.go @@ -0,0 +1,93 @@ +package metrics + +import ( + "net/url" + "os" + "strconv" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // EnvMetricsPushEnabled controls whether collected metrics are pushed to the backend. + // Metrics collection itself is always active (for debug bundles). + // Disabled by default. Set NB_METRICS_PUSH_ENABLED=true to enable push. 
+ EnvMetricsPushEnabled = "NB_METRICS_PUSH_ENABLED" + + // EnvMetricsForceSending if set to true, skips remote configuration fetch and forces metric sending + EnvMetricsForceSending = "NB_METRICS_FORCE_SENDING" + + // EnvMetricsConfigURL is the environment variable to override the URL from which the metrics push config is fetched + EnvMetricsConfigURL = "NB_METRICS_CONFIG_URL" + + // EnvMetricsServerURL is the environment variable to override the metrics server address. + // When set, this takes precedence over the server_url from remote push config. + EnvMetricsServerURL = "NB_METRICS_SERVER_URL" + + // EnvMetricsInterval overrides the push interval from the remote config. + // Only affects how often metrics are pushed; remote config availability + // and version range checks are still respected. + // Format: duration string like "1h", "30m", "4h" + EnvMetricsInterval = "NB_METRICS_INTERVAL" + + defaultMetricsConfigURL = "https://ingest.netbird.io/config" +) + +// IsMetricsPushEnabled returns true if metrics push is enabled via NB_METRICS_PUSH_ENABLED env var. +// Disabled by default. Metrics collection is always active for debug bundles. +func IsMetricsPushEnabled() bool { + enabled, _ := strconv.ParseBool(os.Getenv(EnvMetricsPushEnabled)) + return enabled +} + +// getMetricsInterval returns the metrics push interval from NB_METRICS_INTERVAL env var. +// Returns 0 if not set or invalid.
+func getMetricsInterval() time.Duration { + intervalStr := os.Getenv(EnvMetricsInterval) + if intervalStr == "" { + return 0 + } + interval, err := time.ParseDuration(intervalStr) + if err != nil { + log.Warnf("invalid metrics interval from env %q: %v", intervalStr, err) + return 0 + } + if interval <= 0 { + log.Warnf("invalid metrics interval from env %q: must be positive", intervalStr) + return 0 + } + return interval +} + +func isForceSending() bool { + force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending)) + return force +} + +// getMetricsConfigURL returns the URL to fetch push configuration from +func getMetricsConfigURL() string { + if envURL := os.Getenv(EnvMetricsConfigURL); envURL != "" { + return envURL + } + return defaultMetricsConfigURL +} + +// getMetricsServerURL returns the metrics server URL from NB_METRICS_SERVER_URL env var. +// Returns nil if not set or invalid. +func getMetricsServerURL() *url.URL { + envURL := os.Getenv(EnvMetricsServerURL) + if envURL == "" { + return nil + } + parsed, err := url.ParseRequestURI(envURL) + if err != nil || parsed.Host == "" { + log.Warnf("invalid metrics server URL %q: must be an absolute HTTP(S) URL", envURL) + return nil + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + log.Warnf("invalid metrics server URL %q: unsupported scheme %q", envURL, parsed.Scheme) + return nil + } + return parsed +} diff --git a/client/internal/metrics/influxdb.go b/client/internal/metrics/influxdb.go new file mode 100644 index 000000000..531f6a986 --- /dev/null +++ b/client/internal/metrics/influxdb.go @@ -0,0 +1,219 @@ +package metrics + +import ( + "context" + "fmt" + "io" + "maps" + "slices" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxSampleAge = 5 * 24 * time.Hour // drop samples older than 5 days + maxBufferSize = 5 * 1024 * 1024 // drop oldest samples when estimated size exceeds 5 MB + // estimatedSampleSize is a rough per-sample memory estimate (measurement + tags + 
fields + timestamp) + estimatedSampleSize = 256 +) + +// influxSample is a single InfluxDB line protocol entry. +type influxSample struct { + measurement string + tags string + fields map[string]float64 + timestamp time.Time +} + +// influxDBMetrics collects metric events as timestamped samples. +// Each event is recorded with its exact timestamp, pushed once, then cleared. +type influxDBMetrics struct { + mu sync.Mutex + samples []influxSample +} + +func newInfluxDBMetrics() metricsImplementation { + return &influxDBMetrics{} +} +func (m *influxDBMetrics) RecordConnectionStages( + _ context.Context, + agentInfo AgentInfo, + connectionPairID string, + connectionType ConnectionType, + isReconnection bool, + timestamps ConnectionStageTimestamps, +) { + var signalingReceivedToConnection, connectionToWgHandshake, totalDuration float64 + + if !timestamps.SignalingReceived.IsZero() && !timestamps.ConnectionReady.IsZero() { + signalingReceivedToConnection = timestamps.ConnectionReady.Sub(timestamps.SignalingReceived).Seconds() + } + + if !timestamps.ConnectionReady.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() { + connectionToWgHandshake = timestamps.WgHandshakeSuccess.Sub(timestamps.ConnectionReady).Seconds() + } + + if !timestamps.SignalingReceived.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() { + totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.SignalingReceived).Seconds() + } + + attemptType := "initial" + if isReconnection { + attemptType = "reconnection" + } + + connTypeStr := connectionType.String() + tags := fmt.Sprintf("deployment_type=%s,connection_type=%s,attempt_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,connection_pair_id=%s", + agentInfo.DeploymentType.String(), + connTypeStr, + attemptType, + agentInfo.Version, + agentInfo.OS, + agentInfo.Arch, + agentInfo.peerID, + connectionPairID, + ) + + now := time.Now() + + m.mu.Lock() + defer m.mu.Unlock() + + m.samples = append(m.samples, influxSample{ + measurement: 
"netbird_peer_connection", + tags: tags, + fields: map[string]float64{ + "signaling_to_connection_seconds": signalingReceivedToConnection, + "connection_to_wg_handshake_seconds": connectionToWgHandshake, + "total_seconds": totalDuration, + }, + timestamp: now, + }) + m.trimLocked() + + log.Tracef("peer connection metrics [%s, %s, %s]: signalingReceived→connection: %.3fs, connection→wg_handshake: %.3fs, total: %.3fs", + agentInfo.DeploymentType.String(), connTypeStr, attemptType, signalingReceivedToConnection, connectionToWgHandshake, totalDuration) +} + +func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration) { + tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s", + agentInfo.DeploymentType.String(), + agentInfo.Version, + agentInfo.OS, + agentInfo.Arch, + agentInfo.peerID, + ) + + m.mu.Lock() + defer m.mu.Unlock() + + m.samples = append(m.samples, influxSample{ + measurement: "netbird_sync", + tags: tags, + fields: map[string]float64{ + "duration_seconds": duration.Seconds(), + }, + timestamp: time.Now(), + }) + m.trimLocked() +} + +func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) { + result := "success" + if !success { + result = "failure" + } + + tags := fmt.Sprintf("deployment_type=%s,result=%s,version=%s,os=%s,arch=%s,peer_id=%s", + agentInfo.DeploymentType.String(), + result, + agentInfo.Version, + agentInfo.OS, + agentInfo.Arch, + agentInfo.peerID, + ) + + m.mu.Lock() + defer m.mu.Unlock() + + m.samples = append(m.samples, influxSample{ + measurement: "netbird_login", + tags: tags, + fields: map[string]float64{ + "duration_seconds": duration.Seconds(), + }, + timestamp: time.Now(), + }) + m.trimLocked() + + log.Tracef("login metrics [%s, %s]: duration=%.3fs", agentInfo.DeploymentType.String(), result, duration.Seconds()) +} + +// Export writes pending samples in InfluxDB line protocol format. 
+// Format: measurement,tag=val,tag=val field=val,field=val timestamp_ns +func (m *influxDBMetrics) Export(w io.Writer) error { + m.mu.Lock() + samples := make([]influxSample, len(m.samples)) + copy(samples, m.samples) + m.mu.Unlock() + + for _, s := range samples { + if _, err := fmt.Fprintf(w, "%s,%s ", s.measurement, s.tags); err != nil { + return err + } + + sortedKeys := slices.Sorted(maps.Keys(s.fields)) + first := true + for _, k := range sortedKeys { + if !first { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + if _, err := fmt.Fprintf(w, "%s=%g", k, s.fields[k]); err != nil { + return err + } + first = false + } + + if _, err := fmt.Fprintf(w, " %d\n", s.timestamp.UnixNano()); err != nil { + return err + } + } + return nil +} + +// Reset clears pending samples after a successful push +func (m *influxDBMetrics) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.samples = m.samples[:0] +} + +// trimLocked removes samples that exceed age or size limits. +// Must be called with m.mu held. 
+func (m *influxDBMetrics) trimLocked() { + now := time.Now() + + // drop samples older than maxSampleAge + cutoff := 0 + for cutoff < len(m.samples) && now.Sub(m.samples[cutoff].timestamp) > maxSampleAge { + cutoff++ + } + if cutoff > 0 { + copy(m.samples, m.samples[cutoff:]) + m.samples = m.samples[:len(m.samples)-cutoff] + log.Debugf("influxdb metrics: dropped %d samples older than %s", cutoff, maxSampleAge) + } + + // drop oldest samples if estimated size exceeds maxBufferSize + maxSamples := maxBufferSize / estimatedSampleSize + if len(m.samples) > maxSamples { + drop := len(m.samples) - maxSamples + copy(m.samples, m.samples[drop:]) + m.samples = m.samples[:maxSamples] + log.Debugf("influxdb metrics: dropped %d oldest samples to stay under %d MB size limit", drop, maxBufferSize/(1024*1024)) + } +} diff --git a/client/internal/metrics/influxdb_test.go b/client/internal/metrics/influxdb_test.go new file mode 100644 index 000000000..b964e31a3 --- /dev/null +++ b/client/internal/metrics/influxdb_test.go @@ -0,0 +1,229 @@ +package metrics + +import ( + "bytes" + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInfluxDBMetrics_RecordAndExport(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + ts := ConnectionStageTimestamps{ + SignalingReceived: time.Now().Add(-3 * time.Second), + ConnectionReady: time.Now().Add(-2 * time.Second), + WgHandshakeSuccess: time.Now().Add(-1 * time.Second), + } + + m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_peer_connection,") + assert.Contains(t, output, "connection_to_wg_handshake_seconds=") + 
assert.Contains(t, output, "signaling_to_connection_seconds=") + assert.Contains(t, output, "total_seconds=") +} + +func TestInfluxDBMetrics_ExportDeterministicFieldOrder(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + ts := ConnectionStageTimestamps{ + SignalingReceived: time.Now().Add(-3 * time.Second), + ConnectionReady: time.Now().Add(-2 * time.Second), + WgHandshakeSuccess: time.Now().Add(-1 * time.Second), + } + + // Record multiple times and verify consistent field order + for i := 0; i < 10; i++ { + m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts) + } + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.Len(t, lines, 10) + + // Extract field portion from each line and verify they're all identical + var fieldSections []string + for _, line := range lines { + parts := strings.SplitN(line, " ", 3) + require.Len(t, parts, 3, "each line should have measurement, fields, timestamp") + fieldSections = append(fieldSections, parts[1]) + } + + for i := 1; i < len(fieldSections); i++ { + assert.Equal(t, fieldSections[0], fieldSections[i], "field order should be deterministic across samples") + } + + // Fields should be alphabetically sorted + assert.True(t, strings.HasPrefix(fieldSections[0], "connection_to_wg_handshake_seconds="), + "fields should be sorted: connection_to_wg < signaling_to < total") +} + +func TestInfluxDBMetrics_RecordSyncDuration(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeSelfHosted, + Version: "2.0.0", + OS: "darwin", + Arch: "arm64", + peerID: "def456", + } + + m.RecordSyncDuration(context.Background(), agentInfo, 1500*time.Millisecond) + + var buf bytes.Buffer + 
err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_sync,") + assert.Contains(t, output, "duration_seconds=1.5") + assert.Contains(t, output, "deployment_type=selfhosted") +} + +func TestInfluxDBMetrics_Reset(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + m.RecordSyncDuration(context.Background(), agentInfo, time.Second) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + assert.NotEmpty(t, buf.String()) + + m.Reset() + + buf.Reset() + err = m.Export(&buf) + require.NoError(t, err) + assert.Empty(t, buf.String(), "should be empty after reset") +} + +func TestInfluxDBMetrics_ExportEmpty(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + assert.Empty(t, buf.String()) +} + +func TestInfluxDBMetrics_TrimByAge(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + m.mu.Lock() + m.samples = append(m.samples, influxSample{ + measurement: "old", + tags: "t=1", + fields: map[string]float64{"v": 1}, + timestamp: time.Now().Add(-maxSampleAge - time.Hour), + }) + m.trimLocked() + remaining := len(m.samples) + m.mu.Unlock() + + assert.Equal(t, 0, remaining, "old samples should be trimmed") +} + +func TestInfluxDBMetrics_RecordLoginDuration(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeCloud, + Version: "1.0.0", + OS: "linux", + Arch: "amd64", + peerID: "abc123", + } + + m.RecordLoginDuration(context.Background(), agentInfo, 2500*time.Millisecond, true) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_login,") + assert.Contains(t, output, "duration_seconds=2.5") + 
assert.Contains(t, output, "result=success") +} + +func TestInfluxDBMetrics_RecordLoginDurationFailure(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + agentInfo := AgentInfo{ + DeploymentType: DeploymentTypeSelfHosted, + Version: "1.0.0", + OS: "darwin", + Arch: "arm64", + peerID: "xyz789", + } + + m.RecordLoginDuration(context.Background(), agentInfo, 5*time.Second, false) + + var buf bytes.Buffer + err := m.Export(&buf) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "netbird_login,") + assert.Contains(t, output, "result=failure") + assert.Contains(t, output, "deployment_type=selfhosted") +} + +func TestInfluxDBMetrics_TrimBySize(t *testing.T) { + m := newInfluxDBMetrics().(*influxDBMetrics) + + maxSamples := maxBufferSize / estimatedSampleSize + m.mu.Lock() + for i := 0; i < maxSamples+100; i++ { + m.samples = append(m.samples, influxSample{ + measurement: "test", + tags: "t=1", + fields: map[string]float64{"v": float64(i)}, + timestamp: time.Now(), + }) + } + m.trimLocked() + remaining := len(m.samples) + m.mu.Unlock() + + assert.Equal(t, maxSamples, remaining, "should trim to max samples") +} diff --git a/client/internal/metrics/infra/.env.example b/client/internal/metrics/infra/.env.example new file mode 100644 index 000000000..9c5c1a258 --- /dev/null +++ b/client/internal/metrics/infra/.env.example @@ -0,0 +1,16 @@ +# Copy to .env and adjust values before running docker compose + +# InfluxDB admin (server-side only, never exposed to clients) +INFLUXDB_ADMIN_PASSWORD=changeme +INFLUXDB_ADMIN_TOKEN=changeme + +# Grafana admin credentials +GRAFANA_ADMIN_USER=admin +GRAFANA_ADMIN_PASSWORD=changeme + +# Remote config served by ingest at /config +# Set CONFIG_METRICS_SERVER_URL to the ingest server's public address to enable +CONFIG_METRICS_SERVER_URL= +CONFIG_VERSION_SINCE=0.0.0 +CONFIG_VERSION_UNTIL=99.99.99 +CONFIG_PERIOD_MINUTES=5 diff --git a/client/internal/metrics/infra/.gitignore 
b/client/internal/metrics/infra/.gitignore new file mode 100644 index 000000000..4c49bd78f --- /dev/null +++ b/client/internal/metrics/infra/.gitignore @@ -0,0 +1 @@ +.env diff --git a/client/internal/metrics/infra/README.md b/client/internal/metrics/infra/README.md new file mode 100644 index 000000000..5a93dbd87 --- /dev/null +++ b/client/internal/metrics/infra/README.md @@ -0,0 +1,194 @@ +# Client Metrics + +Internal documentation for the NetBird client metrics system. + +## Overview + +Client metrics track connection performance and sync durations using InfluxDB line protocol (`influxdb.go`). Each event is pushed once then cleared. + +Metrics collection is always active (for debug bundles). Push to backend is: +- Disabled by default (opt-in via `NB_METRICS_PUSH_ENABLED=true`) +- Managed at daemon layer (survives engine restarts) + +## Architecture + +### Layer Separation + +```text +Daemon Layer (connect.go) + ├─ Creates ClientMetrics instance once + ├─ Starts/stops push lifecycle + └─ Updates AgentInfo on profile switch + │ + â–¼ +Engine Layer (engine.go) + └─ Records metrics via ClientMetrics methods +``` + +### Ingest Server + +Clients do not talk to InfluxDB directly. An ingest server sits between clients and InfluxDB: + +```text +Client ──POST──▶ Ingest Server (:8087) ──▶ InfluxDB (internal) + │ + ├─ Validates line protocol + ├─ Allowlists measurements, fields, and tags + ├─ Rejects out-of-bound values + └─ Serves remote config at /config +``` + +- **No secret/token-based client auth** — the ingest server holds the InfluxDB token server-side. Clients must send a hashed peer ID via `X-Peer-ID` header. 
+- **InfluxDB is not exposed** — only accessible within the docker network +- Source: `ingest/main.go` + +## Metrics Collected + +### Connection Stage Timing + +Measurement: `netbird_peer_connection` + +| Field | Timestamps | Description | +|-------|-----------|-------------| +| `signaling_to_connection_seconds` | `SignalingReceived → ConnectionReady` | ICE/relay negotiation time after the first signal is received from the remote peer | +| `connection_to_wg_handshake_seconds` | `ConnectionReady → WgHandshakeSuccess` | WireGuard cryptographic handshake latency once the transport layer is ready | +| `total_seconds` | `SignalingReceived → WgHandshakeSuccess` | End-to-end connection time anchored at the first received signal | + +Tags: +- `deployment_type`: "cloud" | "selfhosted" | "unknown" +- `connection_type`: "ice" | "relay" +- `attempt_type`: "initial" | "reconnection" +- `version`: NetBird version string +- `os`: Operating system (linux, darwin, windows, android, ios, etc.) +- `arch`: CPU architecture (amd64, arm64, etc.) + +**Note:** `SignalingReceived` is set when the first offer or answer arrives from the remote peer (in both initial and reconnection paths). It excludes the potentially unbounded wait for the remote peer to come online. + +### Sync Duration + +Measurement: `netbird_sync` + +| Field | Description | +|-------|-------------| +| `duration_seconds` | Time to process a sync message from management server | + +Tags: +- `deployment_type`: "cloud" | "selfhosted" | "unknown" +- `version`: NetBird version string +- `os`: Operating system (linux, darwin, windows, android, ios, etc.) +- `arch`: CPU architecture (amd64, arm64, etc.) 
+ +### Login Duration + +Measurement: `netbird_login` + +| Field | Description | +|-------|-------------| +| `duration_seconds` | Time to complete the login/auth exchange with management server | + +Tags: +- `deployment_type`: "cloud" | "selfhosted" | "unknown" +- `result`: "success" | "failure" +- `version`: NetBird version string +- `os`: Operating system (linux, darwin, windows, android, ios, etc.) +- `arch`: CPU architecture (amd64, arm64, etc.) + +## Buffer Limits + +The InfluxDB backend limits in-memory sample storage to prevent unbounded growth when pushes fail: +- **Max age:** Samples older than 5 days are dropped +- **Max size:** Estimated buffer size capped at 5 MB (~20k samples) + +## Configuration + +### Client Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `NB_METRICS_PUSH_ENABLED` | `false` | Enable metrics push to backend | +| `NB_METRICS_SERVER_URL` | *(from remote config)* | Ingest server URL (e.g., `https://ingest.netbird.io`) | +| `NB_METRICS_INTERVAL` | *(from remote config)* | Push interval (e.g., "1m", "30m", "4h") | +| `NB_METRICS_FORCE_SENDING` | `false` | Skip remote config, push unconditionally | +| `NB_METRICS_CONFIG_URL` | `https://ingest.netbird.io/config` | Remote push config URL | + +`NB_METRICS_SERVER_URL` and `NB_METRICS_INTERVAL` override their respective values but do not bypass remote config eligibility checks (version range). Use `NB_METRICS_FORCE_SENDING=true` to skip all remote config gating. 
The ingest server serves a remote config JSON at `GET /config` when `CONFIG_METRICS_SERVER_URL` is set. Clients can use `NB_METRICS_CONFIG_URL=http://<ingest-host>:8087/config` (e.g. `http://localhost:8087/config` for local development) to fetch it.
Configure and Start Services + +```bash +# From this directory (client/internal/metrics/infra) +cp .env.example .env +# Edit .env to set INFLUXDB_ADMIN_PASSWORD, INFLUXDB_ADMIN_TOKEN, and GRAFANA_ADMIN_PASSWORD +docker compose up -d +``` + +This starts: +- **Ingest server** on http://localhost:8087 — accepts client metrics (requires `X-Peer-ID` header, no secret/token auth) +- **InfluxDB** — internal only, not exposed to host +- **Grafana** on http://localhost:3001 + +### 2. Configure Client + +```bash +export NB_METRICS_PUSH_ENABLED=true +export NB_METRICS_FORCE_SENDING=true +export NB_METRICS_SERVER_URL=http://localhost:8087 +export NB_METRICS_INTERVAL=1m +``` + +### 3. Run Client + +```bash +cd ../../../.. +go run ./client/ up +``` + +### 4. View in Grafana + +- **InfluxDB dashboard:** http://localhost:3001/d/netbird-influxdb-metrics + +### 5. Verify Data + +```bash +# Query via InfluxDB (using admin token from .env) +docker compose exec influxdb influx query \ + 'from(bucket: "metrics") |> range(start: -1h)' \ + --org netbird + +# Check ingest server health +curl http://localhost:8087/health +``` \ No newline at end of file diff --git a/client/internal/metrics/infra/docker-compose.yml b/client/internal/metrics/infra/docker-compose.yml new file mode 100644 index 000000000..0f2b6b889 --- /dev/null +++ b/client/internal/metrics/infra/docker-compose.yml @@ -0,0 +1,69 @@ +version: '3.8' + +services: + ingest: + container_name: ingest + build: + context: ./ingest + ports: + - "8087:8087" + environment: + - INGEST_LISTEN_ADDR=:8087 + - INFLUXDB_URL=http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns + - INFLUXDB_TOKEN=${INFLUXDB_ADMIN_TOKEN:?required} + - CONFIG_METRICS_SERVER_URL=${CONFIG_METRICS_SERVER_URL:-} + - CONFIG_VERSION_SINCE=${CONFIG_VERSION_SINCE:-0.0.0} + - CONFIG_VERSION_UNTIL=${CONFIG_VERSION_UNTIL:-99.99.99} + - CONFIG_PERIOD_MINUTES=${CONFIG_PERIOD_MINUTES:-5} + depends_on: + - influxdb + restart: unless-stopped + networks: + - 
metrics + + influxdb: + container_name: influxdb + image: influxdb:2 + # No ports exposed — only accessible within the metrics network + volumes: + - influxdb-data:/var/lib/influxdb2 + - ./influxdb/scripts:/docker-entrypoint-initdb.d + environment: + - DOCKER_INFLUXDB_INIT_MODE=setup + - DOCKER_INFLUXDB_INIT_USERNAME=admin + - DOCKER_INFLUXDB_INIT_PASSWORD=${INFLUXDB_ADMIN_PASSWORD:?required} + - DOCKER_INFLUXDB_INIT_ORG=netbird + - DOCKER_INFLUXDB_INIT_BUCKET=metrics + - DOCKER_INFLUXDB_INIT_RETENTION=365d + - DOCKER_INFLUXDB_INIT_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-} + restart: unless-stopped + networks: + - metrics + + grafana: + container_name: grafana + image: grafana/grafana:11.6.0 + ports: + - "3001:3000" + environment: + - GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER:-admin} + - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:?required} + - GF_USERS_ALLOW_SIGN_UP=false + - GF_INSTALL_PLUGINS= + - INFLUXDB_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-} + volumes: + - grafana-data:/var/lib/grafana + - ./grafana/provisioning:/etc/grafana/provisioning + depends_on: + - influxdb + restart: unless-stopped + networks: + - metrics + +volumes: + influxdb-data: + grafana-data: + +networks: + metrics: + driver: bridge diff --git a/client/internal/metrics/infra/grafana/provisioning/dashboards/dashboard.yml b/client/internal/metrics/infra/grafana/provisioning/dashboards/dashboard.yml new file mode 100644 index 000000000..a7e8d3989 --- /dev/null +++ b/client/internal/metrics/infra/grafana/provisioning/dashboards/dashboard.yml @@ -0,0 +1,12 @@ +apiVersion: 1 + +providers: + - name: 'NetBird Dashboards' + orgId: 1 + folder: '' + type: file + disableDeletion: false + updateIntervalSeconds: 10 + allowUiUpdates: true + options: + path: /etc/grafana/provisioning/dashboards/json \ No newline at end of file diff --git a/client/internal/metrics/infra/grafana/provisioning/dashboards/json/netbird-influxdb-metrics.json 
b/client/internal/metrics/infra/grafana/provisioning/dashboards/json/netbird-influxdb-metrics.json new file mode 100644 index 000000000..2bcc9cbab --- /dev/null +++ b/client/internal/metrics/infra/grafana/provisioning/dashboards/json/netbird-influxdb-metrics.json @@ -0,0 +1,280 @@ +{ + "uid": "netbird-influxdb-metrics", + "title": "NetBird Client Metrics (InfluxDB)", + "tags": ["netbird", "connections", "influxdb"], + "timezone": "browser", + "panels": [ + { + "id": 5, + "title": "Sync Duration Extremes", + "type": "stat", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")", + "refId": "A" + }, + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")", + "refId": "B" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0 + } + }, + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "colorMode": "value", + "graphMode": "none", + "textMode": "auto" + } + }, + { + "id": 6, + "title": "Total Connection Time Extremes", + "type": "stat", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: 
v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")", + "refId": "A" + }, + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")", + "refId": "B" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0 + } + }, + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "colorMode": "value", + "graphMode": "none", + "textMode": "auto" + } + }, + { + "id": 1, + "title": "Sync Duration", + "type": "timeseries", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Sync Duration\")", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0, + "custom": { + "drawStyle": "points", + "pointSize": 5 + } + } + } + }, + { + "id": 4, + "title": "ICE vs Relay", + "type": "piechart", + "datasource": { + "type": "influxdb", 
+ "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> drop(columns: [\"deployment_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"connection_pair_id\"])\n |> last()\n |> group(columns: [\"connection_type\"])\n |> count()", + "refId": "A" + } + ], + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "pieType": "donut", + "tooltip": { + "mode": "multi" + } + } + }, + { + "id": 2, + "title": "Connection Stage Durations (avg)", + "type": "bargauge", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"signaling_to_connection_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: [\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Signaling to Connection\"})", + "refId": "A" + }, + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"connection_to_wg_handshake_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: 
[\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Connection to WG Handshake\"})", + "refId": "B" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0 + } + }, + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "orientation": "horizontal", + "displayMode": "gradient" + } + }, + { + "id": 3, + "title": "Total Connection Time", + "type": "timeseries", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 16 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> set(key: \"_field\", value: \"Total Connection Time\")", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0, + "custom": { + "drawStyle": "points", + "pointSize": 5 + } + } + } + }, + { + "id": 7, + "title": "Login Duration", + "type": "timeseries", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Login Duration\")", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ms", + "min": 0, + "custom": { + "drawStyle": "points", + "pointSize": 5 + } + } + } + }, + { + "id": 8, + 
"title": "Login Success vs Failure", + "type": "piechart", + "datasource": { + "type": "influxdb", + "uid": "influxdb" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "targets": [ + { + "query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"result\"])\n |> count()", + "refId": "A" + } + ], + "options": { + "reduceOptions": { + "calcs": ["lastNotNull"] + }, + "pieType": "donut", + "tooltip": { + "mode": "multi" + } + } + } + ], + "schemaVersion": 27, + "version": 2, + "refresh": "30s" +} diff --git a/client/internal/metrics/infra/grafana/provisioning/datasources/influxdb.yml b/client/internal/metrics/infra/grafana/provisioning/datasources/influxdb.yml new file mode 100644 index 000000000..69b96a93a --- /dev/null +++ b/client/internal/metrics/infra/grafana/provisioning/datasources/influxdb.yml @@ -0,0 +1,15 @@ +apiVersion: 1 + +datasources: + - name: InfluxDB + uid: influxdb + type: influxdb + access: proxy + url: http://influxdb:8086 + editable: true + jsonData: + version: Flux + organization: netbird + defaultBucket: metrics + secureJsonData: + token: ${INFLUXDB_ADMIN_TOKEN} \ No newline at end of file diff --git a/client/internal/metrics/infra/influxdb/scripts/create-tokens.sh b/client/internal/metrics/infra/influxdb/scripts/create-tokens.sh new file mode 100755 index 000000000..2464803e8 --- /dev/null +++ b/client/internal/metrics/infra/influxdb/scripts/create-tokens.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Creates a scoped InfluxDB read-only token for Grafana. +# Clients do not need a token — they push via the ingest server. 
+ +BUCKET_ID=$(influx bucket list --org netbird --name metrics --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1) +ORG_ID=$(influx org list --name netbird --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1) + +if [[ -z "$BUCKET_ID" ]] || [[ -z "$ORG_ID" ]]; then + echo "ERROR: Could not determine bucket or org ID" >&2 + echo "BUCKET_ID=$BUCKET_ID ORG_ID=$ORG_ID" >&2 + exit 1 +fi + +# Create read-only token for Grafana +READ_TOKEN=$(influx auth create \ + --org netbird \ + --read-bucket "$BUCKET_ID" \ + --description "Grafana read-only token" \ + --json | grep -oP '"token"\s*:\s*"\K[^"]+' | head -1) + +echo "" +echo "============================================" +echo "GRAFANA READ-ONLY TOKEN:" +echo "$READ_TOKEN" +echo "============================================" \ No newline at end of file diff --git a/client/internal/metrics/infra/ingest/Dockerfile b/client/internal/metrics/infra/ingest/Dockerfile new file mode 100644 index 000000000..3620c524b --- /dev/null +++ b/client/internal/metrics/infra/ingest/Dockerfile @@ -0,0 +1,10 @@ +FROM golang:1.25-alpine AS build +WORKDIR /app +COPY go.mod main.go ./ +RUN CGO_ENABLED=0 go build -o ingest . 
+ +FROM alpine:3.20 +RUN adduser -D -H ingest +COPY --from=build /app/ingest /usr/local/bin/ingest +USER ingest +ENTRYPOINT ["ingest"] \ No newline at end of file diff --git a/client/internal/metrics/infra/ingest/go.mod b/client/internal/metrics/infra/ingest/go.mod new file mode 100644 index 000000000..aaf1ea9da --- /dev/null +++ b/client/internal/metrics/infra/ingest/go.mod @@ -0,0 +1,11 @@ +module github.com/netbirdio/netbird/client/internal/metrics/infra/ingest + +go 1.25 + +require github.com/stretchr/testify v1.11.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/client/internal/metrics/infra/ingest/go.sum b/client/internal/metrics/infra/ingest/go.sum new file mode 100644 index 000000000..c4c1710c4 --- /dev/null +++ b/client/internal/metrics/infra/ingest/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/client/internal/metrics/infra/ingest/main.go b/client/internal/metrics/infra/ingest/main.go new file mode 100644 index 000000000..a5031a873 --- /dev/null +++ 
b/client/internal/metrics/infra/ingest/main.go @@ -0,0 +1,355 @@ +package main + +import ( + "bytes" + "compress/gzip" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +const ( + defaultListenAddr = ":8087" + defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns" + maxBodySize = 50 * 1024 * 1024 // 50 MB max request body + maxDurationSeconds = 300.0 // reject any duration field > 5 minutes + peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars + maxTagValueLength = 64 // reject tag values longer than this +) + +type measurementSpec struct { + allowedFields map[string]bool + allowedTags map[string]bool +} + +var allowedMeasurements = map[string]measurementSpec{ + "netbird_peer_connection": { + allowedFields: map[string]bool{ + "signaling_to_connection_seconds": true, + "connection_to_wg_handshake_seconds": true, + "total_seconds": true, + }, + allowedTags: map[string]bool{ + "deployment_type": true, + "connection_type": true, + "attempt_type": true, + "version": true, + "os": true, + "arch": true, + "peer_id": true, + "connection_pair_id": true, + }, + }, + "netbird_sync": { + allowedFields: map[string]bool{ + "duration_seconds": true, + }, + allowedTags: map[string]bool{ + "deployment_type": true, + "version": true, + "os": true, + "arch": true, + "peer_id": true, + }, + }, + "netbird_login": { + allowedFields: map[string]bool{ + "duration_seconds": true, + }, + allowedTags: map[string]bool{ + "deployment_type": true, + "result": true, + "version": true, + "os": true, + "arch": true, + "peer_id": true, + }, + }, +} + +func main() { + listenAddr := envOr("INGEST_LISTEN_ADDR", defaultListenAddr) + influxURL := envOr("INFLUXDB_URL", defaultInfluxDBURL) + influxToken := os.Getenv("INFLUXDB_TOKEN") + + if influxToken == "" { + log.Fatal("INFLUXDB_TOKEN is required") + } + + client := &http.Client{Timeout: 10 * time.Second} + + http.HandleFunc("/", 
handleIngest(client, influxURL, influxToken)) + + // Build config JSON once at startup from env vars + configJSON := buildConfigJSON() + if configJSON != nil { + log.Printf("serving remote config at /config") + } + + http.HandleFunc("/config", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if configJSON == nil { + http.Error(w, "config not configured", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(configJSON) //nolint:errcheck + }) + + http.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "ok") //nolint:errcheck + }) + + log.Printf("ingest server listening on %s, forwarding to %s", listenAddr, influxURL) + if err := http.ListenAndServe(listenAddr, nil); err != nil { //nolint:gosec + log.Fatal(err) + } +} + +func handleIngest(client *http.Client, influxURL, influxToken string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := validateAuth(r); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + body, err := readBody(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if len(body) > maxBodySize { + http.Error(w, "body too large", http.StatusRequestEntityTooLarge) + return + } + + validated, err := validateLineProtocol(body) + if err != nil { + log.Printf("WARN validation failed from %s: %v", r.RemoteAddr, err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + forwardToInflux(w, r, client, influxURL, influxToken, validated) + } +} + +func forwardToInflux(w http.ResponseWriter, r *http.Request, client *http.Client, influxURL, influxToken string, body []byte) { + req, err := 
http.NewRequestWithContext(r.Context(), http.MethodPost, influxURL, bytes.NewReader(body)) + if err != nil { + log.Printf("ERROR create request: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + req.Header.Set("Content-Type", "text/plain; charset=utf-8") + req.Header.Set("Authorization", "Token "+influxToken) + + resp, err := client.Do(req) + if err != nil { + log.Printf("ERROR forward to influxdb: %v", err) + http.Error(w, "upstream error", http.StatusBadGateway) + return + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) //nolint:errcheck +} + +// validateAuth checks that the X-Peer-ID header contains a valid hashed peer ID. +func validateAuth(r *http.Request) error { + peerID := r.Header.Get("X-Peer-ID") + if peerID == "" { + return fmt.Errorf("missing X-Peer-ID header") + } + if len(peerID) != peerIDLength { + return fmt.Errorf("invalid X-Peer-ID header length") + } + if _, err := hex.DecodeString(peerID); err != nil { + return fmt.Errorf("invalid X-Peer-ID header format") + } + return nil +} + +// readBody reads the request body, decompressing gzip if Content-Encoding indicates it. +func readBody(r *http.Request) ([]byte, error) { + reader := io.LimitReader(r.Body, maxBodySize+1) + + if r.Header.Get("Content-Encoding") == "gzip" { + gz, err := gzip.NewReader(reader) + if err != nil { + return nil, fmt.Errorf("invalid gzip: %w", err) + } + defer gz.Close() + reader = io.LimitReader(gz, maxBodySize+1) + } + + return io.ReadAll(reader) +} + +// validateLineProtocol parses InfluxDB line protocol lines, +// whitelists measurements and fields, and checks value bounds. 
+func validateLineProtocol(body []byte) ([]byte, error) { + lines := strings.Split(strings.TrimSpace(string(body)), "\n") + var valid []string + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + if err := validateLine(line); err != nil { + return nil, err + } + + valid = append(valid, line) + } + + if len(valid) == 0 { + return nil, fmt.Errorf("no valid lines") + } + + return []byte(strings.Join(valid, "\n") + "\n"), nil +} + +func validateLine(line string) error { + // line protocol: measurement,tag=val,tag=val field=val,field=val timestamp + parts := strings.SplitN(line, " ", 3) + if len(parts) < 2 { + return fmt.Errorf("invalid line protocol: %q", truncate(line, 100)) + } + + // parts[0] is "measurement,tag=val,tag=val" + measurementAndTags := strings.Split(parts[0], ",") + measurement := measurementAndTags[0] + + spec, ok := allowedMeasurements[measurement] + if !ok { + return fmt.Errorf("unknown measurement: %q", measurement) + } + + // Validate tags (everything after measurement name in parts[0]) + for _, tagPair := range measurementAndTags[1:] { + if err := validateTag(tagPair, measurement, spec.allowedTags); err != nil { + return err + } + } + + // Validate fields + for _, pair := range strings.Split(parts[1], ",") { + if err := validateField(pair, measurement, spec.allowedFields); err != nil { + return err + } + } + + return nil +} + +func validateTag(pair, measurement string, allowedTags map[string]bool) error { + kv := strings.SplitN(pair, "=", 2) + if len(kv) != 2 { + return fmt.Errorf("invalid tag: %q", pair) + } + + tagName := kv[0] + if !allowedTags[tagName] { + return fmt.Errorf("unknown tag %q in measurement %q", tagName, measurement) + } + + if len(kv[1]) > maxTagValueLength { + return fmt.Errorf("tag value too long for %q: %d > %d", tagName, len(kv[1]), maxTagValueLength) + } + + return nil +} + +func validateField(pair, measurement string, allowedFields map[string]bool) error { + kv := 
strings.SplitN(pair, "=", 2) + if len(kv) != 2 { + return fmt.Errorf("invalid field: %q", pair) + } + + fieldName := kv[0] + if !allowedFields[fieldName] { + return fmt.Errorf("unknown field %q in measurement %q", fieldName, measurement) + } + + val, err := strconv.ParseFloat(kv[1], 64) + if err != nil { + return fmt.Errorf("invalid field value %q for %q", kv[1], fieldName) + } + if val < 0 { + return fmt.Errorf("negative value for %q: %g", fieldName, val) + } + if strings.HasSuffix(fieldName, "_seconds") && val > maxDurationSeconds { + return fmt.Errorf("%q too large: %g > %g", fieldName, val, maxDurationSeconds) + } + + return nil +} + +// buildConfigJSON builds the remote config JSON from env vars. +// Returns nil if required vars are not set. +func buildConfigJSON() []byte { + serverURL := os.Getenv("CONFIG_METRICS_SERVER_URL") + versionSince := envOr("CONFIG_VERSION_SINCE", "0.0.0") + versionUntil := envOr("CONFIG_VERSION_UNTIL", "99.99.99") + periodMinutes := envOr("CONFIG_PERIOD_MINUTES", "5") + + if serverURL == "" { + return nil + } + + period, err := strconv.Atoi(periodMinutes) + if err != nil || period <= 0 { + log.Printf("WARN invalid CONFIG_PERIOD_MINUTES: %q, using 5", periodMinutes) + period = 5 + } + + cfg := map[string]any{ + "server_url": serverURL, + "version-since": versionSince, + "version-until": versionUntil, + "period_minutes": period, + } + + data, err := json.Marshal(cfg) + if err != nil { + log.Printf("ERROR failed to marshal config: %v", err) + return nil + } + return data +} + +func envOr(key, defaultVal string) string { + if v := os.Getenv(key); v != "" { + return v + } + return defaultVal +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." 
// ---- file: client/internal/metrics/infra/ingest/main_test.go ----

package main

import (
	"net/http"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Happy path: a fully tagged peer-connection measurement passes validation.
func TestValidateLine_ValidPeerConnection(t *testing.T) {
	line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789,connection_pair_id=pair1234 signaling_to_connection_seconds=1.5,connection_to_wg_handshake_seconds=0.5,total_seconds=2 1234567890`
	assert.NoError(t, validateLine(line))
}

func TestValidateLine_ValidSync(t *testing.T) {
	line := `netbird_sync,deployment_type=selfhosted,version=2.0.0,os=darwin,arch=arm64,peer_id=abcdef0123456789 duration_seconds=1.5 1234567890`
	assert.NoError(t, validateLine(line))
}

func TestValidateLine_ValidLogin(t *testing.T) {
	line := `netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789 duration_seconds=3.2 1234567890`
	assert.NoError(t, validateLine(line))
}

// Measurements outside the allow-list must be rejected.
func TestValidateLine_UnknownMeasurement(t *testing.T) {
	line := `unknown_metric,foo=bar value=1 1234567890`
	err := validateLine(line)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unknown measurement")
}

// Injected tags not in the measurement's allow-list must be rejected.
func TestValidateLine_UnknownTag(t *testing.T) {
	line := `netbird_sync,deployment_type=cloud,evil_tag=injected,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890`
	err := validateLine(line)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unknown tag")
}

// Injected fields not in the measurement's allow-list must be rejected.
func TestValidateLine_UnknownField(t *testing.T) {
	line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc injected_field=1 1234567890`
	err := validateLine(line)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unknown field")
}

func TestValidateLine_NegativeValue(t *testing.T) {
	line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=-1.5 1234567890`
	err := validateLine(line)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "negative")
}

// Durations above maxDurationSeconds are rejected for *_seconds fields.
func TestValidateLine_DurationTooLarge(t *testing.T) {
	line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
	err := validateLine(line)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "too large")
}

func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
	line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
	err := validateLine(line)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "too large")
}

// Tag values longer than maxTagValueLength are rejected.
func TestValidateLine_TagValueTooLong(t *testing.T) {
	longTag := strings.Repeat("a", maxTagValueLength+1)
	line := `netbird_sync,deployment_type=` + longTag + `,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890`
	err := validateLine(line)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "tag value too long")
}

// A multi-line body with only valid lines passes and all lines survive.
func TestValidateLineProtocol_MultipleLines(t *testing.T) {
	body := []byte(
		"netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" +
			"netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=2.0 1234567890\n",
	)
	validated, err := validateLineProtocol(body)
	require.NoError(t, err)
	assert.Contains(t, string(validated), "netbird_sync")
	assert.Contains(t, string(validated), "netbird_login")
}

// One bad line rejects the whole body — validation is all-or-nothing.
func TestValidateLineProtocol_RejectsOnBadLine(t *testing.T) {
	body := []byte(
		"netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" +
			"evil_metric,foo=bar value=1 1234567890\n",
	)
	_, err := validateLineProtocol(body)
	require.Error(t, err)
}

// X-Peer-ID must be exactly 16 lowercase hex characters (8 hashed bytes).
func TestValidateAuth(t *testing.T) {
	tests := []struct {
		name    string
		peerID  string
		wantErr bool
	}{
		{"valid hex", "abcdef0123456789", false},
		{"empty", "", true},
		{"too short", "abcdef01234567", true},
		{"too long", "abcdef01234567890", true},
		{"invalid hex", "ghijklmnopqrstuv", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r, _ := http.NewRequest(http.MethodPost, "/", nil)
			if tt.peerID != "" {
				r.Header.Set("X-Peer-ID", tt.peerID)
			}
			err := validateAuth(r)
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}

// ---- file: client/internal/metrics/metrics.go ----

package metrics

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)

// AgentInfo holds static information about the agent
type AgentInfo struct {
	DeploymentType DeploymentType
	Version        string
	OS             string // runtime.GOOS (linux, darwin, windows, etc.)
	Arch           string // runtime.GOARCH (amd64, arm64, etc.)
	peerID         string // anonymised peer identifier (SHA-256 of WireGuard public key)
}

// peerIDFromPublicKey returns a truncated SHA-256 hash (8 bytes / 16 hex chars) of the given WireGuard public key.
// The truncation keeps the identifier stable while avoiding shipping the raw key.
func peerIDFromPublicKey(pubKey string) string {
	hash := sha256.Sum256([]byte(pubKey))
	return hex.EncodeToString(hash[:8])
}
// connectionPairID returns a deterministic identifier for a connection between two peers.
// It sorts the two peer IDs before hashing so the same pair always produces the same ID
// regardless of which side computes it.
func connectionPairID(peerID1, peerID2 string) string {
	a, b := peerID1, peerID2
	if a > b {
		a, b = b, a
	}
	hash := sha256.Sum256([]byte(a + b))
	return hex.EncodeToString(hash[:8])
}

// metricsImplementation defines the internal interface for metrics implementations
type metricsImplementation interface {
	// RecordConnectionStages records connection stage metrics from timestamps
	RecordConnectionStages(
		ctx context.Context,
		agentInfo AgentInfo,
		connectionPairID string,
		connectionType ConnectionType,
		isReconnection bool,
		timestamps ConnectionStageTimestamps,
	)

	// RecordSyncDuration records how long it took to process a sync message
	RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)

	// RecordLoginDuration records how long the login to management took
	RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)

	// Export exports metrics in InfluxDB line protocol format
	Export(w io.Writer) error

	// Reset clears all collected metrics
	Reset()
}

// ClientMetrics is the public entry point for recording client-side metrics.
// All methods are nil-receiver-safe so WASM builds (where the constructor
// returns nil) need no call-site guards.
type ClientMetrics struct {
	impl metricsImplementation

	// mu guards agentInfo; pushMu guards push/pushCancel. They are separate
	// locks — do not take mu while holding pushMu's critical section needs it.
	agentInfo AgentInfo
	mu        sync.RWMutex

	push       *Push
	pushMu     sync.Mutex
	wg         sync.WaitGroup
	pushCancel context.CancelFunc
}

// ConnectionStageTimestamps holds timestamps for each connection stage
type ConnectionStageTimestamps struct {
	SignalingReceived  time.Time // First signal received from remote peer (both initial and reconnection)
	ConnectionReady    time.Time
	WgHandshakeSuccess time.Time
}

// String returns a human-readable representation of the connection stage timestamps
func (c ConnectionStageTimestamps) String() string {
	return fmt.Sprintf("ConnectionStageTimestamps{SignalingReceived=%v, ConnectionReady=%v, WgHandshakeSuccess=%v}",
		c.SignalingReceived.Format(time.RFC3339Nano),
		c.ConnectionReady.Format(time.RFC3339Nano),
		c.WgHandshakeSuccess.Format(time.RFC3339Nano),
	)
}

// RecordConnectionStages calculates stage durations from timestamps and records them.
// remotePubKey is the remote peer's WireGuard public key; it will be hashed for anonymisation.
func (c *ClientMetrics) RecordConnectionStages(
	ctx context.Context,
	remotePubKey string,
	connectionType ConnectionType,
	isReconnection bool,
	timestamps ConnectionStageTimestamps,
) {
	if c == nil {
		return
	}
	// Snapshot agentInfo under the read lock; the impl call happens unlocked.
	c.mu.RLock()
	agentInfo := c.agentInfo
	c.mu.RUnlock()

	remotePeerID := peerIDFromPublicKey(remotePubKey)
	pairID := connectionPairID(agentInfo.peerID, remotePeerID)
	c.impl.RecordConnectionStages(ctx, agentInfo, pairID, connectionType, isReconnection, timestamps)
}

// RecordSyncDuration records the duration of sync message processing
func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Duration) {
	if c == nil {
		return
	}
	c.mu.RLock()
	agentInfo := c.agentInfo
	c.mu.RUnlock()

	c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}

// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
	if c == nil {
		return
	}
	c.mu.RLock()
	agentInfo := c.agentInfo
	c.mu.RUnlock()

	c.impl.RecordLoginDuration(ctx, agentInfo, duration, success)
}
// UpdateAgentInfo updates the agent information (e.g., when switching profiles).
// publicKey is the WireGuard public key; it will be hashed for anonymisation.
// If a push loop is already running, its peer ID header is updated in place.
func (c *ClientMetrics) UpdateAgentInfo(agentInfo AgentInfo, publicKey string) {
	if c == nil {
		return
	}

	agentInfo.peerID = peerIDFromPublicKey(publicKey)

	c.mu.Lock()
	c.agentInfo = agentInfo
	c.mu.Unlock()

	// Snapshot the push pointer under pushMu, then call SetPeerID unlocked.
	c.pushMu.Lock()
	push := c.push
	c.pushMu.Unlock()
	if push != nil {
		push.SetPeerID(agentInfo.peerID)
	}
}

// Export exports metrics to the writer
func (c *ClientMetrics) Export(w io.Writer) error {
	if c == nil {
		return nil
	}

	return c.impl.Export(w)
}

// StartPush starts periodic pushing of metrics with the given configuration
// Precedence: PushConfig.ServerAddress > remote config server_url
func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
	if c == nil {
		return
	}

	c.pushMu.Lock()
	defer c.pushMu.Unlock()

	// At most one push loop per ClientMetrics instance.
	if c.push != nil {
		log.Warnf("metrics push already running")
		return
	}

	c.mu.RLock()
	agentVersion := c.agentInfo.Version
	peerID := c.agentInfo.peerID
	c.mu.RUnlock()

	configManager := remoteconfig.NewManager(getMetricsConfigURL(), remoteconfig.DefaultMinRefreshInterval)
	push, err := NewPush(c.impl, configManager, config, agentVersion)
	if err != nil {
		log.Errorf("failed to create metrics push: %v", err)
		return
	}
	push.SetPeerID(peerID)

	ctx, cancel := context.WithCancel(ctx)
	c.pushCancel = cancel

	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		push.Start(ctx)
	}()
	c.push = push
}

// StopPush cancels the running push loop (if any) and waits for it to exit.
func (c *ClientMetrics) StopPush() {
	if c == nil {
		return
	}
	c.pushMu.Lock()
	defer c.pushMu.Unlock()
	if c.push == nil {
		return
	}

	c.pushCancel()
	c.wg.Wait()
	c.push = nil
}

// ---- file: client/internal/metrics/metrics_default.go ----

//go:build !js

package metrics

// NewClientMetrics creates a new ClientMetrics instance
func NewClientMetrics(agentInfo AgentInfo) *ClientMetrics {
	return &ClientMetrics{
		impl:      newInfluxDBMetrics(),
		agentInfo: agentInfo,
	}
}

// ---- file: client/internal/metrics/metrics_js.go ----

//go:build js

package metrics

// NewClientMetrics returns nil on WASM builds — all ClientMetrics methods are nil-safe.
func NewClientMetrics(AgentInfo) *ClientMetrics {
	return nil
}

// ---- file: client/internal/metrics/push.go ----

package metrics

import (
	"bytes"
	"compress/gzip"
	"context"
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"time"

	goversion "github.com/hashicorp/go-version"
	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)

const (
	// defaultPushInterval is the default interval for pushing metrics
	defaultPushInterval = 5 * time.Minute
)

// defaultMetricsServerURL is used as fallback when NB_METRICS_FORCE_SENDING is true
var defaultMetricsServerURL *url.URL

func init() {
	// The literal is a known-good constant, so the parse error is ignored.
	defaultMetricsServerURL, _ = url.Parse("https://ingest.netbird.io")
}

// PushConfig holds configuration for metrics push
type PushConfig struct {
	// ServerAddress is the metrics server URL. If nil, uses remote config server_url.
	ServerAddress *url.URL
	// Interval is how often to push metrics. If 0, uses remote config interval or defaultPushInterval.
	Interval time.Duration
	// ForceSending skips remote configuration fetch and version checks, pushing unconditionally.
	ForceSending bool
}
// PushConfigFromEnv builds a PushConfig from environment variables.
func PushConfigFromEnv() PushConfig {
	config := PushConfig{}

	config.ForceSending = isForceSending()
	config.ServerAddress = getMetricsServerURL()
	config.Interval = getMetricsInterval()

	return config
}

// remoteConfigProvider abstracts remote push config fetching for testability
type remoteConfigProvider interface {
	RefreshIfNeeded(ctx context.Context) *remoteconfig.Config
}

// Push handles periodic pushing of metrics
type Push struct {
	metrics       metricsImplementation
	configManager remoteConfigProvider
	agentVersion  *goversion.Version // may be nil when ForceSending and the version string is unparsable

	// peerID is mutable (profile switches); guarded by peerMu.
	peerID string
	peerMu sync.RWMutex

	client          *http.Client
	cfgForceSending bool
	cfgInterval     time.Duration
	cfgAddress      *url.URL
}

// NewPush creates a new Push instance with configuration resolution.
// With ForceSending, missing interval/address fall back to package defaults
// and an unparsable agent version is tolerated; without it, a bad version is
// a hard error because the remote-config version gate cannot be evaluated.
func NewPush(metrics metricsImplementation, configManager remoteConfigProvider, config PushConfig, agentVersion string) (*Push, error) {
	var cfgInterval time.Duration
	var cfgAddress *url.URL

	if config.ForceSending {
		cfgInterval = config.Interval
		if config.Interval <= 0 {
			cfgInterval = defaultPushInterval
		}

		cfgAddress = config.ServerAddress
		if cfgAddress == nil {
			cfgAddress = defaultMetricsServerURL
		}
	} else {
		cfgAddress = config.ServerAddress

		// A negative interval is ignored (cfgInterval stays 0, meaning
		// "use remote config interval").
		if config.Interval < 0 {
			log.Warnf("negative metrics push interval %s", config.Interval)
		} else {
			cfgInterval = config.Interval
		}
	}

	parsedVersion, err := goversion.NewVersion(agentVersion)
	if err != nil {
		if !config.ForceSending {
			return nil, fmt.Errorf("parse agent version %q: %w", agentVersion, err)
		}
	}

	return &Push{
		metrics:         metrics,
		configManager:   configManager,
		agentVersion:    parsedVersion,
		cfgForceSending: config.ForceSending,
		cfgInterval:     cfgInterval,
		cfgAddress:      cfgAddress,
		client: &http.Client{
			Timeout: 10 * time.Second,
		},
	}, nil
}

// SetPeerID updates the hashed peer ID sent in the X-Peer-ID request header.
func (p *Push) SetPeerID(peerID string) {
	p.peerMu.Lock()
	p.peerID = peerID
	p.peerMu.Unlock()
}

// Start starts the periodic push loop.
// The env interval override controls tick frequency but does not bypass remote config
// version gating. Use ForceSending to skip remote config entirely.
func (p *Push) Start(ctx context.Context) {
	// Log initial state
	switch {
	case p.cfgForceSending:
		log.Infof("started metrics push with force sending to %s, interval %s", p.cfgAddress, p.cfgInterval)
	case p.cfgAddress != nil:
		log.Infof("started metrics push with server URL override: %s", p.cfgAddress.String())
	default:
		log.Infof("started metrics push, server URL will be resolved from remote config")
	}

	timer := time.NewTimer(0) // fire immediately on first iteration
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			log.Debug("stopping metrics push")
			return
		case <-timer.C:
		}

		pushURL, interval := p.resolve(ctx)
		if pushURL != "" {
			if err := p.push(ctx, pushURL); err != nil {
				log.Errorf("failed to push metrics: %v", err)
			}
		}

		// Guard against a zero/negative interval from resolve.
		if interval <= 0 {
			interval = defaultPushInterval
		}
		timer.Reset(interval)
	}
}
// resolve returns the push URL and interval for the next cycle.
// Returns empty pushURL to skip this cycle.
func (p *Push) resolve(ctx context.Context) (pushURL string, interval time.Duration) {
	// ForceSending bypasses remote config entirely.
	if p.cfgForceSending {
		return p.resolveServerURL(nil), p.cfgInterval
	}

	config := p.configManager.RefreshIfNeeded(ctx)
	if config == nil {
		log.Debug("no metrics push config available, waiting to retry")
		return "", defaultPushInterval
	}

	// prefer env variables instead of remote config
	if p.cfgInterval > 0 {
		interval = p.cfgInterval
	} else {
		interval = config.Interval
	}

	// Version gate applies even with an env interval override.
	if !isVersionInRange(p.agentVersion, config.VersionSince, config.VersionUntil) {
		log.Debugf("agent version %s not in range [%s, %s), skipping metrics push",
			p.agentVersion, config.VersionSince, config.VersionUntil)
		return "", interval
	}

	pushURL = p.resolveServerURL(&config.ServerURL)
	if pushURL == "" {
		log.Warn("no metrics server URL available, skipping push")
	}
	return pushURL, interval
}

// push exports metrics and sends them to the metrics server.
// Metrics are Reset only after a successful (2xx) push, so a failed upload
// keeps the data for the next cycle.
func (p *Push) push(ctx context.Context, pushURL string) error {
	// Export metrics without clearing
	var buf bytes.Buffer
	if err := p.metrics.Export(&buf); err != nil {
		return fmt.Errorf("export metrics: %w", err)
	}

	// Don't push if there are no metrics
	if buf.Len() == 0 {
		log.Tracef("no metrics to push")
		return nil
	}

	// Gzip compress the body
	compressed, err := gzipCompress(buf.Bytes())
	if err != nil {
		return fmt.Errorf("gzip compress: %w", err)
	}

	// Create HTTP request
	req, err := http.NewRequestWithContext(ctx, "POST", pushURL, compressed)
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "text/plain; charset=utf-8")
	req.Header.Set("Content-Encoding", "gzip")

	p.peerMu.RLock()
	peerID := p.peerID
	p.peerMu.RUnlock()
	if peerID != "" {
		req.Header.Set("X-Peer-ID", peerID)
	}

	// Send request
	resp, err := p.client.Do(req)
	if err != nil {
		return fmt.Errorf("send request: %w", err)
	}
	defer func() {
		if resp.Body == nil {
			return
		}
		if err := resp.Body.Close(); err != nil {
			log.Warnf("failed to close response body: %v", err)
		}
	}()

	// Check response status
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("push failed with status %d", resp.StatusCode)
	}

	log.Debugf("successfully pushed metrics to %s", pushURL)
	// NOTE(review): metrics recorded between Export and Reset are dropped
	// here — presumably acceptable for best-effort telemetry; confirm.
	p.metrics.Reset()
	return nil
}

// resolveServerURL determines the push URL.
// Precedence: envAddress (env var) > remote config server_url
func (p *Push) resolveServerURL(remoteServerURL *url.URL) string {
	var baseURL *url.URL
	if p.cfgAddress != nil {
		baseURL = p.cfgAddress
	} else {
		baseURL = remoteServerURL
	}

	if baseURL == nil {
		return ""
	}

	return baseURL.String()
}

// gzipCompress compresses data using gzip and returns the compressed buffer.
func gzipCompress(data []byte) (*bytes.Buffer, error) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write(data); err != nil {
		_ = gz.Close() // best-effort close; the Write error takes precedence
		return nil, err
	}
	if err := gz.Close(); err != nil {
		return nil, err
	}
	return &buf, nil
}

// isVersionInRange checks if current falls within [since, until)
func isVersionInRange(current, since, until *goversion.Version) bool {
	return !current.LessThan(since) && current.LessThan(until)
}

// ---- file: client/internal/metrics/push_test.go ----

package metrics

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"sync/atomic"
	"testing"
	"time"

	goversion "github.com/hashicorp/go-version"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)
// mustVersion parses a semver string, panicking on failure (test helper).
func mustVersion(s string) *goversion.Version {
	v, err := goversion.NewVersion(s)
	if err != nil {
		panic(err)
	}
	return v
}

// mustURL parses a URL by value, panicking on failure (test helper).
func mustURL(s string) url.URL {
	u, err := url.Parse(s)
	if err != nil {
		panic(err)
	}
	return *u
}

// parseURL parses a URL pointer, panicking on failure (test helper).
func parseURL(s string) *url.URL {
	u, err := url.Parse(s)
	if err != nil {
		panic(err)
	}
	return u
}

// testConfig builds a remoteconfig.Config fixture.
func testConfig(serverURL, since, until string, period time.Duration) *remoteconfig.Config {
	return &remoteconfig.Config{
		ServerURL:    mustURL(serverURL),
		VersionSince: mustVersion(since),
		VersionUntil: mustVersion(until),
		Interval:     period,
	}
}

// mockConfigProvider implements remoteConfigProvider for testing
type mockConfigProvider struct {
	config *remoteconfig.Config
}

func (m *mockConfigProvider) RefreshIfNeeded(_ context.Context) *remoteconfig.Config {
	return m.config
}

// mockMetrics implements metricsImplementation for testing
type mockMetrics struct {
	exportData string
}

func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ string, _ ConnectionType, _ bool, _ ConnectionStageTimestamps) {
}

func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}

func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}

func (m *mockMetrics) Export(w io.Writer) error {
	if m.exportData != "" {
		_, err := w.Write([]byte(m.exportData))
		return err
	}
	return nil
}

func (m *mockMetrics) Reset() {
}

// A short env-override interval drives repeated pushes to the test server.
func TestPush_OverrideIntervalPushes(t *testing.T) {
	var pushCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		pushCount.Add(1)
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{
		Interval:      50 * time.Millisecond,
		ServerAddress: parseURL(server.URL),
	}, "1.0.0")
	require.NoError(t, err)

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		push.Start(ctx)
		close(done)
	}()

	require.Eventually(t, func() bool {
		return pushCount.Load() >= 3
	}, 2*time.Second, 10*time.Millisecond)

	cancel()
	<-done
}

func TestPush_RemoteConfigVersionInRange(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.NotEmpty(t, pushURL)
	assert.Equal(t, 1*time.Minute, interval)
}

func TestPush_RemoteConfigVersionOutOfRange(t *testing.T) {
	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: testConfig("http://localhost", "1.0.0", "1.5.0", 1*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{}, "2.0.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.Empty(t, pushURL)
	assert.Equal(t, 1*time.Minute, interval)
}

func TestPush_NoConfigReturnsDefault(t *testing.T) {
	metrics := &mockMetrics{}
	configProvider := &mockConfigProvider{config: nil}

	push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.Empty(t, pushURL)
	assert.Equal(t, defaultPushInterval, interval)
}

// The env interval override must NOT bypass the remote-config version gate.
func TestPush_OverrideIntervalRespectsVersionCheck(t *testing.T) {
	metrics := &mockMetrics{}
	configProvider := &mockConfigProvider{config: testConfig("http://localhost", "3.0.0", "4.0.0", 60*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{
		Interval:      30 * time.Second,
		ServerAddress: parseURL("http://localhost"),
	}, "1.0.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.Empty(t, pushURL)                  // version out of range
	assert.Equal(t, 30*time.Second, interval) // but uses override interval
}

func TestPush_OverrideIntervalUsedWhenVersionInRange(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{}
	configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{
		Interval: 30 * time.Second,
	}, "1.5.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.NotEmpty(t, pushURL)
	assert.Equal(t, 30*time.Second, interval)
}

// An empty export must not result in an HTTP request.
func TestPush_NoMetricsSkipsPush(t *testing.T) {
	var pushCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		pushCount.Add(1)
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{exportData: ""} // no metrics to export
	configProvider := &mockConfigProvider{config: nil}

	push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0")
	require.NoError(t, err)

	err = push.push(context.Background(), server.URL)
	assert.NoError(t, err)
	assert.Equal(t, int32(0), pushCount.Load())
}

func TestPush_ServerURLFromRemoteConfig(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.Contains(t, pushURL, server.URL)
	assert.Equal(t, 1*time.Minute, interval)
}

func TestPush_ServerAddressOverridesTakePrecedenceOverRemoteConfig(t *testing.T) {
	overrideServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))
	defer overrideServer.Close()

	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: testConfig("http://remote-config-server", "1.0.0", "2.0.0", 1*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{
		ServerAddress: parseURL(overrideServer.URL),
	}, "1.5.0")
	require.NoError(t, err)

	pushURL, _ := push.resolve(context.Background())
	assert.Contains(t, pushURL, overrideServer.URL)
	assert.NotContains(t, pushURL, "remote-config-server")
}

func TestPush_OverrideIntervalWithoutOverrideURL_UsesRemoteConfigURL(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}

	push, err := NewPush(metrics, configProvider, PushConfig{
		Interval: 30 * time.Second,
	}, "1.0.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.Contains(t, pushURL, server.URL)
	assert.Equal(t, 30*time.Second, interval)
}

func TestPush_NoConfigSkipsPush(t *testing.T) {
	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: nil}

	push, err := NewPush(metrics, configProvider, PushConfig{
		Interval: 30 * time.Second,
	}, "1.0.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.Empty(t, pushURL)
	assert.Equal(t, defaultPushInterval, interval) // no config available, use default retry interval
}

func TestPush_ForceSendingSkipsRemoteConfig(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: nil}

	push, err := NewPush(metrics, configProvider, PushConfig{
		ForceSending:  true,
		Interval:      1 * time.Minute,
		ServerAddress: parseURL(server.URL),
	}, "1.0.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.NotEmpty(t, pushURL)
	assert.Equal(t, 1*time.Minute, interval)
}

func TestPush_ForceSendingUsesDefaultInterval(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))
	defer server.Close()

	metrics := &mockMetrics{exportData: "test_metric 1\n"}
	configProvider := &mockConfigProvider{config: nil}

	push, err := NewPush(metrics, configProvider, PushConfig{
		ForceSending:  true,
		ServerAddress: parseURL(server.URL),
	}, "1.0.0")
	require.NoError(t, err)

	pushURL, interval := push.resolve(context.Background())
	assert.NotEmpty(t, pushURL)
	assert.Equal(t, defaultPushInterval, interval)
}

// Half-open range semantics: since is inclusive, until is exclusive.
func TestIsVersionInRange(t *testing.T) {
	tests := []struct {
		name     string
		current  string
		since    string
		until    string
		expected bool
	}{
		{"at lower bound inclusive", "1.2.2", "1.2.2", "1.2.3", true},
		{"in range", "1.2.2", "1.2.0", "1.3.0", true},
		{"at upper bound exclusive", "1.2.3", "1.2.2", "1.2.3", false},
		{"below range", "1.2.1", "1.2.2", "1.2.3", false},
		{"above range", "1.3.0", "1.2.2", "1.2.3", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, isVersionInRange(mustVersion(tt.current), mustVersion(tt.since), mustVersion(tt.until)))
		})
	}
}

// ---- file: client/internal/metrics/remoteconfig/manager.go ----

package remoteconfig

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"sync"
	"time"

	goversion "github.com/hashicorp/go-version"
	log "github.com/sirupsen/logrus"
)

const (
	DefaultMinRefreshInterval = 30 * time.Minute
)

// Config holds the parsed remote push configuration
type Config struct {
	ServerURL    url.URL
	VersionSince *goversion.Version
	VersionUntil *goversion.Version
	Interval     time.Duration
}

// rawConfig is the JSON wire format fetched from the remote server
type rawConfig struct {
	ServerURL     string `json:"server_url"`
	VersionSince  string `json:"version-since"`
	VersionUntil  string `json:"version-until"`
	PeriodMinutes int    `json:"period_minutes"`
}

// Manager handles fetching and caching remote push configuration
type Manager struct {
	configURL          string
	minRefreshInterval time.Duration
	client             *http.Client

	// mu guards lastConfig/lastFetched; it is held across the fetch so
	// concurrent callers do not stampede the config server.
	mu          sync.Mutex
	lastConfig  *Config
	lastFetched time.Time
}

// NewManager creates a Manager that refetches at most every minRefreshInterval.
func NewManager(configURL string, minRefreshInterval time.Duration) *Manager {
	return &Manager{
		configURL:          configURL,
		minRefreshInterval: minRefreshInterval,
		client: &http.Client{
			Timeout: 10 * time.Second,
		},
	}
}
// RefreshIfNeeded fetches new config if the cached one is stale.
// Returns the current config (possibly just fetched) or nil if unavailable.
func (m *Manager) RefreshIfNeeded(ctx context.Context) *Config {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.isConfigFresh() {
		return m.lastConfig
	}

	// lastFetched is advanced even on failure so a broken server is not
	// hammered more often than minRefreshInterval.
	fetchedConfig, err := m.fetch(ctx)
	m.lastFetched = time.Now()
	if err != nil {
		log.Warnf("failed to fetch metrics remote config: %v", err)
		return m.lastConfig // return cached (may be nil)
	}

	m.lastConfig = fetchedConfig

	log.Tracef("fetched metrics remote config: version-since=%s version-until=%s period=%s",
		fetchedConfig.VersionSince, fetchedConfig.VersionUntil, fetchedConfig.Interval)

	return fetchedConfig
}

// isConfigFresh reports whether the cached config is still within the
// refresh window. Caller must hold m.mu.
func (m *Manager) isConfigFresh() bool {
	if m.lastConfig == nil {
		return false
	}
	return time.Since(m.lastFetched) < m.minRefreshInterval
}

// fetch downloads and validates the remote config. The response body is
// capped at 4096 bytes; a larger body yields truncated JSON and a parse error.
func (m *Manager) fetch(ctx context.Context) (*Config, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, m.configURL, nil)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}

	resp, err := m.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer func() {
		if resp.Body != nil {
			_ = resp.Body.Close()
		}
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, 4096))
	if err != nil {
		return nil, fmt.Errorf("read body: %w", err)
	}

	var raw rawConfig
	if err := json.Unmarshal(body, &raw); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	if raw.PeriodMinutes <= 0 {
		return nil, fmt.Errorf("invalid period_minutes: %d", raw.PeriodMinutes)
	}

	if raw.ServerURL == "" {
		return nil, fmt.Errorf("server_url is required")
	}

	serverURL, err := url.Parse(raw.ServerURL)
	if err != nil {
		return nil, fmt.Errorf("parse server_url %q: %w", raw.ServerURL, err)
	}

	since, err := goversion.NewVersion(raw.VersionSince)
	if err != nil {
		return nil, fmt.Errorf("parse version-since %q: %w", raw.VersionSince, err)
	}

	until, err := goversion.NewVersion(raw.VersionUntil)
	if err != nil {
		return nil, fmt.Errorf("parse version-until %q: %w", raw.VersionUntil, err)
	}

	return &Config{
		ServerURL:    *serverURL,
		VersionSince: since,
		VersionUntil: until,
		Interval:     time.Duration(raw.PeriodMinutes) * time.Minute,
	}, nil
}

// ---- file: client/internal/metrics/remoteconfig/manager_test.go ----

package remoteconfig

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

const testMinRefresh = 100 * time.Millisecond

func TestManager_FetchSuccess(t *testing.T) {
	server := newConfigServer(t, rawConfig{
		ServerURL:     "https://ingest.example.com",
		VersionSince:  "1.0.0",
		VersionUntil:  "2.0.0",
		PeriodMinutes: 60,
	})
	defer server.Close()

	mgr := NewManager(server.URL, testMinRefresh)
	config := mgr.RefreshIfNeeded(context.Background())

	require.NotNil(t, config)
	assert.Equal(t, "https://ingest.example.com", config.ServerURL.String())
	assert.Equal(t, "1.0.0", config.VersionSince.String())
	assert.Equal(t, "2.0.0", config.VersionUntil.String())
	assert.Equal(t, 60*time.Minute, config.Interval)
}

// Within minRefreshInterval the manager serves the cached config.
func TestManager_CachesConfig(t *testing.T) {
	var fetchCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fetchCount.Add(1)
		err := json.NewEncoder(w).Encode(rawConfig{
			ServerURL:     "https://ingest.example.com",
			VersionSince:  "1.0.0",
			VersionUntil:  "2.0.0",
			PeriodMinutes: 60,
		})
		require.NoError(t, err)
	}))
	defer server.Close()

	mgr := NewManager(server.URL, testMinRefresh)

	// First call fetches
	config1 := mgr.RefreshIfNeeded(context.Background())
	require.NotNil(t, config1)
	assert.Equal(t, int32(1), fetchCount.Load())

	// Second call uses cache (within minRefreshInterval)
	config2 := mgr.RefreshIfNeeded(context.Background())
	require.NotNil(t, config2)
	assert.Equal(t, int32(1), fetchCount.Load())
	assert.Equal(t, config1, config2)
}

func TestManager_RefetchesWhenStale(t *testing.T) {
	var fetchCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fetchCount.Add(1)
		err := json.NewEncoder(w).Encode(rawConfig{
			ServerURL:     "https://ingest.example.com",
			VersionSince:  "1.0.0",
			VersionUntil:  "2.0.0",
			PeriodMinutes: 60,
		})
		require.NoError(t, err)
	}))
	defer server.Close()

	mgr := NewManager(server.URL, testMinRefresh)

	// First fetch
	mgr.RefreshIfNeeded(context.Background())
	assert.Equal(t, int32(1), fetchCount.Load())

	// Wait for config to become stale
	time.Sleep(testMinRefresh + 10*time.Millisecond)

	// Should refetch
	mgr.RefreshIfNeeded(context.Background())
	assert.Equal(t, int32(2), fetchCount.Load())
}

func TestManager_FetchFailureReturnsNil(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer server.Close()

	mgr := NewManager(server.URL, testMinRefresh)
	config := mgr.RefreshIfNeeded(context.Background())

	assert.Nil(t, config)
}
server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + + // First call succeeds + config1 := mgr.RefreshIfNeeded(context.Background()) + require.NotNil(t, config1) + + // Wait for config to become stale + time.Sleep(testMinRefresh + 10*time.Millisecond) + + // Second call fails but returns cached + config2 := mgr.RefreshIfNeeded(context.Background()) + require.NotNil(t, config2) + assert.Equal(t, config1, config2) +} + +func TestManager_RejectsInvalidPeriod(t *testing.T) { + tests := []struct { + name string + period int + }{ + {"zero", 0}, + {"negative", -5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := newConfigServer(t, rawConfig{ + ServerURL: "https://ingest.example.com", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: tt.period, + }) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + assert.Nil(t, config) + }) + } +} + +func TestManager_RejectsEmptyServerURL(t *testing.T) { + server := newConfigServer(t, rawConfig{ + ServerURL: "", + VersionSince: "1.0.0", + VersionUntil: "2.0.0", + PeriodMinutes: 60, + }) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + assert.Nil(t, config) +} + +func TestManager_RejectsInvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("not json")) + require.NoError(t, err) + })) + defer server.Close() + + mgr := NewManager(server.URL, testMinRefresh) + config := mgr.RefreshIfNeeded(context.Background()) + assert.Nil(t, config) +} + +func newConfigServer(t *testing.T, config rawConfig) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(config) + 
require.NoError(t, err) + })) +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0f213a6fb..8d1585b3f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/client/internal/metrics" "github.com/netbirdio/netbird/client/internal/peer/conntype" "github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peer/guard" @@ -27,6 +28,17 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" ) +// MetricsRecorder is an interface for recording peer connection metrics +type MetricsRecorder interface { + RecordConnectionStages( + ctx context.Context, + remotePubKey string, + connectionType metrics.ConnectionType, + isReconnection bool, + timestamps metrics.ConnectionStageTimestamps, + ) +} + type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler @@ -35,6 +47,7 @@ type ServiceDependencies struct { SrWatcher *guard.SRWatcher PeerConnDispatcher *dispatcher.ConnectionDispatcher PortForwardManager *portforward.Manager + MetricsRecorder MetricsRecorder } type WgConfig struct { @@ -118,6 +131,10 @@ type Conn struct { dumpState *stateDump endpointUpdater *EndpointUpdater + + // Connection stage timestamps for metrics + metricsRecorder MetricsRecorder + metricsStages *MetricsStages } // NewConn creates a new not opened Conn to the remote peer. 
@@ -144,6 +161,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { dumpState: dumpState, endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), + metricsRecorder: services.MetricsRecorder, } return conn, nil @@ -160,6 +178,9 @@ func (conn *Conn) Open(engineCtx context.Context) error { return nil } + // Allocate new metrics stages so old goroutines don't corrupt new state + conn.metricsStages = &MetricsStages{} + conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) @@ -171,7 +192,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { } conn.workerICE = workerICE - conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) + conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) if !isForceRelayed() { @@ -339,7 +360,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn if conn.currentConnPriority > priority { conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority) conn.statusICE.SetConnected() - conn.updateIceState(iceConnInfo) + conn.updateIceState(iceConnInfo, time.Now()) return } @@ -379,7 +400,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn } conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) - conn.enableWgWatcherIfNeeded() + updateTime := time.Now() + conn.enableWgWatcherIfNeeded(updateTime) presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { @@ 
-395,8 +417,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn conn.currentConnPriority = priority conn.statusICE.SetConnected() - conn.updateIceState(iceConnInfo) - conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) + conn.updateIceState(iceConnInfo, updateTime) + conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr, updateTime) } func (conn *Conn) onICEStateDisconnected(sessionChanged bool) { @@ -448,6 +470,10 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) { conn.disableWgWatcherIfNeeded() + if conn.currentConnPriority == conntype.None { + conn.metricsStages.Disconnected() + } + peerState := State{ PubKey: conn.config.Key, ConnStatus: conn.evalStatus(), @@ -488,7 +514,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.setRelayedProxy(wgProxy) conn.statusRelay.SetConnected() - conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, time.Now()) return } @@ -497,7 +523,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { if controller { wgProxy.Work() } - conn.enableWgWatcherIfNeeded() + updateTime := time.Now() + conn.enableWgWatcherIfNeeded(updateTime) if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) @@ -508,13 +535,16 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { if !controller { wgProxy.Work() } + + wgConfigWorkaround() + conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.currentConnPriority = conntype.Relay conn.statusRelay.SetConnected() conn.setRelayedProxy(wgProxy) - 
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey, updateTime) conn.Log.Infof("start to communicate with peer via relay") - conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) + conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr, updateTime) } func (conn *Conn) onRelayDisconnected() { @@ -552,6 +582,10 @@ func (conn *Conn) handleRelayDisconnectedLocked() { conn.disableWgWatcherIfNeeded() + if conn.currentConnPriority == conntype.None { + conn.metricsStages.Disconnected() + } + peerState := State{ PubKey: conn.config.Key, ConnStatus: conn.evalStatus(), @@ -592,10 +626,10 @@ func (conn *Conn) onWGDisconnected() { } } -func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { +func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte, updateTime time.Time) { peerState := State{ PubKey: conn.config.Key, - ConnStatusUpdate: time.Now(), + ConnStatusUpdate: updateTime, ConnStatus: conn.evalStatus(), Relayed: conn.isRelayed(), RelayServerAddress: relayServerAddr, @@ -608,10 +642,10 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by } } -func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) { +func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo, updateTime time.Time) { peerState := State{ PubKey: conn.config.Key, - ConnStatusUpdate: time.Now(), + ConnStatusUpdate: updateTime, ConnStatus: conn.evalStatus(), Relayed: iceConnInfo.Relayed, LocalIceCandidateType: iceConnInfo.LocalIceCandidateType, @@ -649,11 +683,13 @@ func (conn *Conn) setStatusToDisconnected() { } } -func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string) { +func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string, updateTime time.Time) { if runtime.GOOS == "ios" { runtime.GC() } + 
conn.metricsStages.RecordConnectionReady(updateTime) + if conn.onConnected != nil { conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr) } @@ -705,14 +741,14 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { return true } -func (conn *Conn) enableWgWatcherIfNeeded() { +func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { if !conn.wgWatcher.IsEnabled() { wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx) conn.wgWatcherCancel = wgWatcherCancel conn.wgWatcherWg.Add(1) go func() { defer conn.wgWatcherWg.Done() - conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected) + conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess) }() } } @@ -787,6 +823,41 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { conn.wgProxyRelay = proxy } +// onWGHandshakeSuccess is called when the first WireGuard handshake is detected +func (conn *Conn) onWGHandshakeSuccess(when time.Time) { + conn.metricsStages.RecordWGHandshakeSuccess(when) + conn.recordConnectionMetrics() +} + +// recordConnectionMetrics records connection stage timestamps as metrics +func (conn *Conn) recordConnectionMetrics() { + if conn.metricsRecorder == nil { + return + } + + // Determine connection type based on current priority + conn.mu.Lock() + priority := conn.currentConnPriority + conn.mu.Unlock() + + var connType metrics.ConnectionType + switch priority { + case conntype.Relay: + connType = metrics.ConnectionTypeRelay + default: + connType = metrics.ConnectionTypeICE + } + + // Record metrics with timestamps - duration calculation happens in metrics package + conn.metricsRecorder.RecordConnectionStages( + context.Background(), + conn.config.Key, + connType, + conn.metricsStages.IsReconnection(), + conn.metricsStages.GetTimestamps(), + ) +} + // AllowedIP returns the allowed IP of the remote peer func (conn *Conn) AllowedIP() 
netip.Addr { return conn.config.WgConfig.AllowedIps[0].Addr() diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index aff26f847..9b50cecd1 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -44,12 +44,13 @@ type OfferAnswer struct { } type Handshaker struct { - mu sync.Mutex - log *log.Entry - config ConnConfig - signaler *Signaler - ice *WorkerICE - relay *WorkerRelay + mu sync.Mutex + log *log.Entry + config ConnConfig + signaler *Signaler + ice *WorkerICE + relay *WorkerRelay + metricsStages *MetricsStages // relayListener is not blocking because the listener is using a goroutine to process the messages // and it will only keep the latest message if multiple offers are received in a short time // this is to avoid blocking the handshaker if the listener is doing some heavy processing @@ -64,13 +65,14 @@ type Handshaker struct { remoteAnswerCh chan OfferAnswer } -func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker { +func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { return &Handshaker{ log: log, config: config, signaler: signaler, ice: ice, relay: relay, + metricsStages: metricsStages, remoteOffersCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer), } @@ -89,6 +91,12 @@ func (h *Handshaker) Listen(ctx context.Context) { select { case remoteOfferAnswer := <-h.remoteOffersCh: h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + + // Record signaling received for reconnection attempts + if h.metricsStages != nil { + h.metricsStages.RecordSignalingReceived() + } + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } @@ -103,6 +111,12 @@ func (h 
*Handshaker) Listen(ctx context.Context) { } case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + + // Record signaling received for reconnection attempts + if h.metricsStages != nil { + h.metricsStages.RecordSignalingReceived() + } + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } diff --git a/client/internal/peer/metrics_saver.go b/client/internal/peer/metrics_saver.go new file mode 100644 index 000000000..e32afbfe5 --- /dev/null +++ b/client/internal/peer/metrics_saver.go @@ -0,0 +1,73 @@ +package peer + +import ( + "sync" + "time" + + "github.com/netbirdio/netbird/client/internal/metrics" +) + +type MetricsStages struct { + isReconnectionAttempt bool // Track if current attempt is a reconnection + stageTimestamps metrics.ConnectionStageTimestamps + mu sync.Mutex +} + +// RecordSignalingReceived records when the first signal is received from the remote peer. +// Used as the base for all subsequent stage durations to avoid inflating metrics when +// the remote peer was offline. +func (s *MetricsStages) RecordSignalingReceived() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.stageTimestamps.SignalingReceived.IsZero() { + s.stageTimestamps.SignalingReceived = time.Now() + } +} + +func (s *MetricsStages) RecordConnectionReady(when time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + if s.stageTimestamps.ConnectionReady.IsZero() { + s.stageTimestamps.ConnectionReady = when + } +} + +func (s *MetricsStages) RecordWGHandshakeSuccess(handshakeTime time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.stageTimestamps.ConnectionReady.IsZero() && s.stageTimestamps.WgHandshakeSuccess.IsZero() { + // WireGuard only reports handshake times with second precision, but ConnectionReady + // is captured with microsecond precision. 
If handshake appears before ConnectionReady + // due to truncation (e.g., handshake at 6.042s truncated to 6.000s), normalize to + // ConnectionReady to avoid negative duration metrics. + if handshakeTime.Before(s.stageTimestamps.ConnectionReady) { + s.stageTimestamps.WgHandshakeSuccess = s.stageTimestamps.ConnectionReady + } else { + s.stageTimestamps.WgHandshakeSuccess = handshakeTime + } + } +} + +// Disconnected sets the mode to reconnection. It is called only when both ICE and Relay have been disconnected at the same time. +func (s *MetricsStages) Disconnected() { + s.mu.Lock() + defer s.mu.Unlock() + + // Reset all timestamps for reconnection + s.stageTimestamps = metrics.ConnectionStageTimestamps{} + s.isReconnectionAttempt = true +} + +func (s *MetricsStages) IsReconnection() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.isReconnectionAttempt +} + +func (s *MetricsStages) GetTimestamps() metrics.ConnectionStageTimestamps { + s.mu.Lock() + defer s.mu.Unlock() + return s.stageTimestamps +} diff --git a/client/internal/peer/metrics_saver_test.go b/client/internal/peer/metrics_saver_test.go new file mode 100644 index 000000000..01c0aa9ac --- /dev/null +++ b/client/internal/peer/metrics_saver_test.go @@ -0,0 +1,125 @@ +package peer + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/metrics" +) + +func TestMetricsStages_RecordSignalingReceived(t *testing.T) { + s := &MetricsStages{} + + s.RecordSignalingReceived() + ts := s.GetTimestamps() + require.False(t, ts.SignalingReceived.IsZero()) + + // Second call should not overwrite + first := ts.SignalingReceived + time.Sleep(time.Millisecond) + s.RecordSignalingReceived() + ts = s.GetTimestamps() + assert.Equal(t, first, ts.SignalingReceived, "should keep the first signaling timestamp") +} + +func TestMetricsStages_RecordConnectionReady(t *testing.T) { + s := &MetricsStages{} + + now := time.Now() 
+ s.RecordConnectionReady(now) + ts := s.GetTimestamps() + assert.Equal(t, now, ts.ConnectionReady) + + // Second call should not overwrite + later := now.Add(time.Second) + s.RecordConnectionReady(later) + ts = s.GetTimestamps() + assert.Equal(t, now, ts.ConnectionReady, "should keep the first connection ready timestamp") +} + +func TestMetricsStages_RecordWGHandshakeSuccess(t *testing.T) { + s := &MetricsStages{} + + connReady := time.Now() + s.RecordConnectionReady(connReady) + + handshake := connReady.Add(500 * time.Millisecond) + s.RecordWGHandshakeSuccess(handshake) + + ts := s.GetTimestamps() + assert.Equal(t, handshake, ts.WgHandshakeSuccess) +} + +func TestMetricsStages_HandshakeBeforeConnectionReady_Normalizes(t *testing.T) { + s := &MetricsStages{} + + connReady := time.Now() + s.RecordConnectionReady(connReady) + + // WG handshake appears before ConnectionReady due to second-precision truncation + handshake := connReady.Add(-100 * time.Millisecond) + s.RecordWGHandshakeSuccess(handshake) + + ts := s.GetTimestamps() + assert.Equal(t, connReady, ts.WgHandshakeSuccess, "should normalize to ConnectionReady when handshake appears earlier") +} + +func TestMetricsStages_HandshakeIgnoredWithoutConnectionReady(t *testing.T) { + s := &MetricsStages{} + + s.RecordWGHandshakeSuccess(time.Now()) + ts := s.GetTimestamps() + assert.True(t, ts.WgHandshakeSuccess.IsZero(), "should not record handshake without connection ready") +} + +func TestMetricsStages_HandshakeRecordedOnce(t *testing.T) { + s := &MetricsStages{} + + connReady := time.Now() + s.RecordConnectionReady(connReady) + + first := connReady.Add(time.Second) + s.RecordWGHandshakeSuccess(first) + + // Second call (rekey) should be ignored + second := connReady.Add(2 * time.Second) + s.RecordWGHandshakeSuccess(second) + + ts := s.GetTimestamps() + assert.Equal(t, first, ts.WgHandshakeSuccess, "should preserve first handshake, ignore rekeys") +} + +func TestMetricsStages_Disconnected(t *testing.T) { + s := 
&MetricsStages{} + + s.RecordSignalingReceived() + s.RecordConnectionReady(time.Now()) + assert.False(t, s.IsReconnection()) + + s.Disconnected() + + assert.True(t, s.IsReconnection()) + ts := s.GetTimestamps() + assert.True(t, ts.SignalingReceived.IsZero(), "timestamps should be reset after disconnect") + assert.True(t, ts.ConnectionReady.IsZero(), "timestamps should be reset after disconnect") + assert.True(t, ts.WgHandshakeSuccess.IsZero(), "timestamps should be reset after disconnect") +} + +func TestMetricsStages_GetTimestamps(t *testing.T) { + s := &MetricsStages{} + + ts := s.GetTimestamps() + assert.Equal(t, metrics.ConnectionStageTimestamps{}, ts) + + now := time.Now() + s.RecordSignalingReceived() + s.RecordConnectionReady(now) + + ts = s.GetTimestamps() + assert.False(t, ts.SignalingReceived.IsZero()) + assert.Equal(t, now, ts.ConnectionReady) + assert.True(t, ts.WgHandshakeSuccess.IsZero()) +} diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 799a9375e..805a6f24a 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -48,7 +48,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin // EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. // The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management. 
-func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) { +func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) { w.muEnabled.Lock() if w.enabled { w.muEnabled.Unlock() @@ -56,7 +56,6 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func() } w.log.Debugf("enable WireGuard watcher") - enabledTime := time.Now() w.enabled = true w.muEnabled.Unlock() @@ -65,7 +64,7 @@ func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func() w.log.Warnf("failed to read initial wg stats: %v", err) } - w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake) + w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake) w.muEnabled.Lock() w.enabled = false @@ -89,7 +88,7 @@ func (w *WGWatcher) Reset() { } // wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) { +func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time), enabledTime time.Time, initialHandshake time.Time) { w.log.Infof("WireGuard watcher started") timer := time.NewTimer(wgHandshakeOvertime) @@ -108,6 +107,9 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn if lastHandshake.IsZero() { elapsed := calcElapsed(enabledTime, *handshake) w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) + if onHandshakeSuccessFn != nil { + onHandshakeSuccessFn(*handshake) + } } lastHandshake = *handshake diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go index f79405a01..3ce91cd46 100644 --- a/client/internal/peer/wg_watcher_test.go +++ 
b/client/internal/peer/wg_watcher_test.go @@ -35,9 +35,11 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) { defer cancel() onDisconnected := make(chan struct{}, 1) - go watcher.EnableWgWatcher(ctx, func() { + go watcher.EnableWgWatcher(ctx, time.Now(), func() { mlog.Infof("onDisconnectedFn") onDisconnected <- struct{}{} + }, func(when time.Time) { + mlog.Infof("onHandshakeSuccess: %v", when) }) // wait for initial reading @@ -64,7 +66,7 @@ func TestWGWatcher_ReEnable(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - watcher.EnableWgWatcher(ctx, func() {}) + watcher.EnableWgWatcher(ctx, time.Now(), func() {}, func(when time.Time) {}) }() cancel() @@ -75,9 +77,9 @@ func TestWGWatcher_ReEnable(t *testing.T) { defer cancel() onDisconnected := make(chan struct{}, 1) - go watcher.EnableWgWatcher(ctx, func() { + go watcher.EnableWgWatcher(ctx, time.Now(), func() { onDisconnected <- struct{}{} - }) + }, func(when time.Time) {}) time.Sleep(2 * time.Second) mocWgIface.disconnect() diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index bad616271..e6ef8b876 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -3,7 +3,9 @@ package client import ( "context" "fmt" + "net" "reflect" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -564,7 +566,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler { return dnsinterceptor.New(params) case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(params.WgInterface) - dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()) + dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) return dynamic.NewRoute(params, dnsAddr) default: return static.NewRoute(params) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 4bf0d5476..64f2a8789 100644 --- 
a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "net" "net/netip" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -249,7 +251,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) + upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10)) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 7c9fa021a..0e330bdac 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -1,12 +1,10 @@ #!/usr/bin/env bash set -eEuo pipefail -: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} +: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="30"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() -log_file_path="" _log() { # mimic Go logger's output for easier parsing @@ -33,60 +31,29 @@ on_exit() { fi } -wait_for_message() { - local timeout="${1}" message="${2}" - if test "${timeout}" -eq 0; then - info "not waiting for log line ${message@Q} due to zero timeout." - elif test -n "${log_file_path}"; then - info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) - else - info "log file unsupported, sleeping for ${timeout} seconds..." 
- sleep "${timeout}" - fi -} - -locate_log_file() { - local log_files_string="${1}" - - while read -r log_file; do - case "${log_file}" in - console | syslog) ;; - *) - log_file_path="${log_file}" - return - ;; - esac - done < <(sed 's#,#\n#g' <<<"${log_files_string}") - - warn "log files parsing for ${log_files_string@Q} is not supported by debug bundles" - warn "please consider removing the \$NB_LOG_FILE or setting it to real file, before gathering debug bundles." -} - wait_for_daemon_startup() { local timeout="${1}" - - if test -n "${log_file_path}"; then - if ! wait_for_message "${timeout}" "started daemon server"; then - warn "log line containing 'started daemon server' not found after ${timeout} seconds" - warn "daemon failed to start, exiting..." - exit 1 - fi - else - warn "daemon service startup not discovered, sleeping ${timeout} instead" - sleep "${timeout}" + if [[ "${timeout}" -eq 0 ]]; then + info "not waiting for daemon startup due to zero timeout." + return fi + + local deadline=$((SECONDS + timeout)) + while [[ "${SECONDS}" -lt "${deadline}" ]]; do + if "${NETBIRD_BIN}" status --check live 2>/dev/null; then + return + fi + sleep 1 + done + + warn "daemon did not become responsive after ${timeout} seconds, exiting..." + exit 1 } -login_if_needed() { - local timeout="${1}" - - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then - info "already logged in, skipping 'netbird up'..." - else - info "logging in..." - "${NETBIRD_BIN}" up - fi +connect() { + info "running 'netbird up'..." + "${NETBIRD_BIN}" up + return $? 
} main() { @@ -95,9 +62,8 @@ main() { service_pids+=("$!") info "registered new service process 'netbird service run', currently running: ${service_pids[@]@Q}" - locate_log_file "${NB_LOG_FILE}" wait_for_daemon_startup "${NB_ENTRYPOINT_SERVICE_TIMEOUT}" - login_if_needed "${NB_ENTRYPOINT_LOGIN_TIMEOUT}" + connect wait "${service_pids[@]}" } diff --git a/client/server/debug.go b/client/server/debug.go index 4c531efba..81708e576 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -26,6 +26,15 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( log.Warnf("failed to get latest sync response: %v", err) } + var clientMetrics debug.MetricsExporter + if s.connectClient != nil { + if engine := s.connectClient.Engine(); engine != nil { + if cm := engine.GetClientMetrics(); cm != nil { + clientMetrics = cm + } + } + } + var cpuProfileData []byte if s.cpuProfileBuf != nil && !s.cpuProfiling { cpuProfileData = s.cpuProfileBuf.Bytes() @@ -54,6 +63,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( LogPath: s.logFile, CPUProfile: cpuProfileData, RefreshStatus: refreshStatus, + ClientMetrics: clientMetrics, }, debug.BundleConfig{ Anonymize: req.GetAnonymize(), diff --git a/client/status/status.go b/client/status/status.go index f13163a41..8c932bbab 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -25,6 +25,38 @@ import ( "github.com/netbirdio/netbird/version" ) +// DaemonStatus represents the current state of the NetBird daemon. +// These values mirror internal.StatusType but are defined here to avoid an import cycle. 
+type DaemonStatus string + +const ( + DaemonStatusIdle DaemonStatus = "Idle" + DaemonStatusConnecting DaemonStatus = "Connecting" + DaemonStatusConnected DaemonStatus = "Connected" + DaemonStatusNeedsLogin DaemonStatus = "NeedsLogin" + DaemonStatusLoginFailed DaemonStatus = "LoginFailed" + DaemonStatusSessionExpired DaemonStatus = "SessionExpired" +) + +// ParseDaemonStatus converts a raw status string to DaemonStatus. +// Unrecognized values are preserved as-is to remain visible during version skew. +func ParseDaemonStatus(s string) DaemonStatus { + return DaemonStatus(s) +} + +// ConvertOptions holds parameters for ConvertToStatusOutputOverview. +type ConvertOptions struct { + Anonymize bool + DaemonVersion string + DaemonStatus DaemonStatus + StatusFilter string + PrefixNamesFilter []string + PrefixNamesFilterMap map[string]struct{} + IPsFilter map[string]struct{} + ConnectionTypeFilter string + ProfileName string +} + type PeerStateDetailOutput struct { FQDN string `json:"fqdn" yaml:"fqdn"` IP string `json:"netbirdIp" yaml:"netbirdIp"` @@ -102,6 +134,7 @@ type OutputOverview struct { Peers PeersStateOutput `json:"peers" yaml:"peers"` CliVersion string `json:"cliVersion" yaml:"cliVersion"` DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"` + DaemonStatus DaemonStatus `json:"daemonStatus" yaml:"daemonStatus"` ManagementState ManagementStateOutput `json:"management" yaml:"management"` SignalState SignalStateOutput `json:"signal" yaml:"signal"` Relays RelayStateOutput `json:"relays" yaml:"relays"` @@ -120,7 +153,8 @@ type OutputOverview struct { SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"` } -func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { +// ConvertToStatusOutputOverview converts protobuf 
status to the output overview. +func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertOptions) OutputOverview { managementState := pbFullStatus.GetManagementState() managementOverview := ManagementStateOutput{ URL: managementState.GetURL(), @@ -137,12 +171,13 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da relayOverview := mapRelays(pbFullStatus.GetRelays()) sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState()) - peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) + peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter) overview := OutputOverview{ Peers: peersOverview, CliVersion: version.NetbirdVersion(), - DaemonVersion: daemonVersion, + DaemonVersion: opts.DaemonVersion, + DaemonStatus: opts.DaemonStatus, ManagementState: managementOverview, SignalState: signalOverview, Relays: relayOverview, @@ -157,11 +192,11 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, da NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), Events: mapEvents(pbFullStatus.GetEvents()), LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(), - ProfileName: profName, + ProfileName: opts.ProfileName, SSHServerState: sshServerOverview, } - if anon { + if opts.Anonymize { anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) anonymizeOverview(anonymizer, &overview) } diff --git a/client/status/status_test.go b/client/status/status_test.go index b02d78d64..7754eebae 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -176,6 +176,7 @@ var overview = OutputOverview{ Events: []SystemEventOutput{}, CliVersion: version.NetbirdVersion(), DaemonVersion: "0.14.1", + DaemonStatus: DaemonStatusConnected, ManagementState: ManagementStateOutput{ URL: 
"my-awesome-management.com:443", Connected: true, @@ -238,7 +239,10 @@ var overview = OutputOverview{ } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { - convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "") + convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), ConvertOptions{ + DaemonVersion: resp.GetDaemonVersion(), + DaemonStatus: ParseDaemonStatus(resp.GetStatus()), + }) assert.Equal(t, overview, convertedResult) } @@ -329,6 +333,7 @@ func TestParsingToJSON(t *testing.T) { }, "cliVersion": "development", "daemonVersion": "0.14.1", + "daemonStatus": "Connected", "management": { "url": "my-awesome-management.com:443", "connected": true, @@ -452,6 +457,7 @@ func TestParsingToYAML(t *testing.T) { networks: [] cliVersion: development daemonVersion: 0.14.1 +daemonStatus: Connected management: url: my-awesome-management.com:443 connected: true diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 0574e53d0..b1e0aec41 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -324,6 +324,7 @@ type serviceClient struct { exitNodeMu sync.Mutex mExitNodeItems []menuHandler exitNodeRetryCancel context.CancelFunc + mExitNodeSeparator *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window diff --git a/client/ui/network.go b/client/ui/network.go index ed03f5ada..571e871bb 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -421,6 +421,10 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { node.Remove() } s.mExitNodeItems = nil + if s.mExitNodeSeparator != nil { + s.mExitNodeSeparator.Remove() + s.mExitNodeSeparator = nil + } if s.mExitNodeDeselectAll != nil { s.mExitNodeDeselectAll.Remove() s.mExitNodeDeselectAll = nil @@ -453,31 +457,37 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { } if showDeselectAll { - s.mExitNode.AddSeparator() - 
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") - s.mExitNodeDeselectAll = deselectAllItem - go func() { - for { - _, ok := <-deselectAllItem.ClickedCh - if !ok { - // channel closed: exit the goroutine - return - } - exitNodes, err := s.handleExitNodeMenuDeselectAll() - if err != nil { - log.Warnf("failed to handle deselect all exit nodes: %v", err) - } else { - s.exitNodeMu.Lock() - s.recreateExitNodeMenu(exitNodes) - s.exitNodeMu.Unlock() - } - } - - }() + s.addExitNodeDeselectAll() } } +func (s *serviceClient) addExitNodeDeselectAll() { + sep := s.mExitNode.AddSubMenuItem("───────────────", "") + sep.Disable() + s.mExitNodeSeparator = sep + + deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") + s.mExitNodeDeselectAll = deselectAllItem + + go func() { + for { + _, ok := <-deselectAllItem.ClickedCh + if !ok { + return + } + exitNodes, err := s.handleExitNodeMenuDeselectAll() + if err != nil { + log.Warnf("failed to handle deselect all exit nodes: %v", err) + } else { + s.exitNodeMu.Lock() + s.recreateExitNodeMenu(exitNodes) + s.exitNodeMu.Unlock() + } + } + }() +} + func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) { ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout) defer cancel() diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 26022ffc7..d8e50ab6d 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -18,7 +18,6 @@ import ( "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/version" ) const ( @@ -350,7 +349,7 @@ func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) pbFullStatus := fullStatus.ToProto() - return nbstatus.ConvertToStatusOutputOverview(pbFullStatus, false, version.NetbirdVersion(), "", nil, nil, nil, "", ""), nil + return 
nbstatus.ConvertToStatusOutputOverview(pbFullStatus, nbstatus.ConvertOptions{}), nil } // createStatusMethod creates the status method that returns JSON diff --git a/go.mod b/go.mod index 0a481a7af..31d50e10e 100644 --- a/go.mod +++ b/go.mod @@ -17,23 +17,23 @@ require ( github.com/spf13/cobra v1.10.1 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.46.0 - golang.org/x/sys v0.39.0 + golang.org/x/crypto v0.48.0 + golang.org/x/sys v0.41.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.77.0 - google.golang.org/protobuf v1.36.10 + google.golang.org/grpc v1.79.3 + google.golang.org/protobuf v1.36.11 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) require ( fyne.io/fyne/v2 v2.7.0 fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 - github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/awnumar/memguard v0.23.0 github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/config v1.29.14 + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 @@ -102,21 +102,21 @@ require ( github.com/vmihailenco/msgpack/v5 v5.4.1 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 - go.opentelemetry.io/otel v1.38.0 - go.opentelemetry.io/otel/exporters/prometheus v0.48.0 - go.opentelemetry.io/otel/metric v1.38.0 - go.opentelemetry.io/otel/sdk/metric v1.38.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 + go.opentelemetry.io/otel v1.42.0 + go.opentelemetry.io/otel/exporters/prometheus v0.64.0 + go.opentelemetry.io/otel/metric v1.42.0 + go.opentelemetry.io/otel/sdk/metric v1.42.0 go.uber.org/mock v0.5.2 go.uber.org/zap v1.27.0 
goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20251113184115-a159579294ab - golang.org/x/mod v0.30.0 - golang.org/x/net v0.47.0 + golang.org/x/mod v0.32.0 + golang.org/x/net v0.51.0 golang.org/x/oauth2 v0.34.0 golang.org/x/sync v0.19.0 - golang.org/x/term v0.38.0 + golang.org/x/term v0.40.0 golang.org/x/time v0.14.0 google.golang.org/api v0.257.0 gopkg.in/yaml.v3 v3.0.1 @@ -145,7 +145,6 @@ require ( github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/awnumar/memcall v0.4.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect @@ -254,12 +253,13 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.66.1 // indirect - github.com/prometheus/procfs v0.16.1 // indirect + github.com/prometheus/common v0.67.5 // indirect + github.com/prometheus/otlptranslator v1.0.0 // indirect + github.com/prometheus/procfs v0.19.2 // indirect github.com/russellhaering/goxmldsig v1.5.0 // indirect github.com/rymdport/portal v0.4.2 // indirect github.com/shirou/gopsutil/v4 v4.25.1 // indirect - github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/shoenig/go-m1cpu v0.2.1 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect @@ -274,15 +274,15 @@ require ( github.com/zeebo/blake3 v0.2.3 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 
// indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect - go.opentelemetry.io/otel/trace v1.38.0 // indirect + go.opentelemetry.io/otel/sdk v1.42.0 // indirect + go.opentelemetry.io/otel/trace v1.42.0 // indirect go.uber.org/multierr v1.11.0 // indirect - go.yaml.in/yaml/v2 v2.4.2 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/image v0.33.0 // indirect - golang.org/x/text v0.32.0 // indirect - golang.org/x/tools v0.39.0 // indirect + golang.org/x/text v0.34.0 // indirect + golang.org/x/tools v0.41.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/go.sum b/go.sum index 9fbc2ce0a..a1d2bb71f 100644 --- a/go.sum +++ b/go.sum @@ -34,8 +34,6 @@ github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSC github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= -github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo= -github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= @@ -497,10 +495,12 @@ github.com/prometheus/client_golang v1.23.2 
h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= -github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= +github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos= +github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= +github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= +github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk= github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= @@ -520,10 +520,12 @@ github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRB github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8= github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs= github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI= -github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6/go.mod 
h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= -github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= +github.com/shoenig/go-m1cpu v0.2.1 h1:yqRB4fvOge2+FyRXFkXqsyMoqPazv14Yyy+iyccT2E4= +github.com/shoenig/go-m1cpu v0.2.1/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk= +github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -611,26 +613,26 @@ github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= -go.opentelemetry.io/otel v1.38.0 
h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= -go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= +go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= -go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= -go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= -go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= -go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= -go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= -go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= -go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= -go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= -go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.opentelemetry.io/otel/exporters/prometheus v0.64.0 h1:g0LRDXMX/G1SEZtK8zl8Chm4K6GBwRkjPKE36LxiTYs= +go.opentelemetry.io/otel/exporters/prometheus v0.64.0/go.mod h1:UrgcjnarfdlBDP3GjDIJWe6HTprwSazNjwsI+Ru6hro= +go.opentelemetry.io/otel/metric v1.42.0 
h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= +go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= +go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= +go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= +go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= +go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= +go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -641,8 +643,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= -go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -656,8 +658,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m 
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= @@ -674,8 +676,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -694,8 +696,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod 
h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= @@ -746,8 +748,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -760,8 +762,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod 
h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -773,8 +775,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -788,8 +790,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= 
-golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -807,12 +809,12 @@ google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3 google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= -google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= -google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= -google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= -google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod 
h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= +google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -823,8 +825,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= -google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index 
7cb0f3908..d3f8f44ff 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -154,9 +154,11 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs return err } - eventsToStore = append(eventsToStore, func() { - m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) - }) + if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") { + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + } return nil }) diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index 861d026a7..859f1c5b2 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -17,6 +17,9 @@ type Domain struct { // SupportsCustomPorts is populated at query time for free domains from the // proxy cluster capabilities. Not persisted. SupportsCustomPorts *bool `gorm:"-"` + // RequireSubdomain is populated at query time. When true, the domain + // cannot be used bare and a subdomain label must be prepended. Not persisted. 
+ RequireSubdomain *bool `gorm:"-"` } // EventMeta returns activity event metadata for a domain diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index d26a6a418..640ab28a5 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -47,6 +47,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain { Type: domainTypeToApi(d.Type), Validated: d.Validated, SupportsCustomPorts: d.SupportsCustomPorts, + RequireSubdomain: d.RequireSubdomain, } if d.TargetCluster != "" { resp.TargetCluster = &d.TargetCluster diff --git a/management/internals/modules/reverseproxy/domain/manager/domain_test.go b/management/internals/modules/reverseproxy/domain/manager/domain_test.go new file mode 100644 index 000000000..523920a99 --- /dev/null +++ b/management/internals/modules/reverseproxy/domain/manager/domain_test.go @@ -0,0 +1,172 @@ +package manager + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" +) + +func TestExtractClusterFromFreeDomain(t *testing.T) { + clusters := []string{"eu1.proxy.netbird.io", "us1.proxy.netbird.io"} + + tests := []struct { + name string + domain string + wantOK bool + wantVal string + }{ + { + name: "subdomain of cluster matches", + domain: "myapp.eu1.proxy.netbird.io", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "deep subdomain of cluster matches", + domain: "foo.bar.eu1.proxy.netbird.io", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "bare cluster domain matches", + domain: "eu1.proxy.netbird.io", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "unrelated domain does not match", + domain: "example.com", + wantOK: false, + }, + { + name: "partial suffix does not match", + domain: "fakeu1.proxy.netbird.io", + 
wantOK: false, + }, + { + name: "second cluster matches", + domain: "app.us1.proxy.netbird.io", + wantOK: true, + wantVal: "us1.proxy.netbird.io", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cluster, ok := ExtractClusterFromFreeDomain(tc.domain, clusters) + assert.Equal(t, tc.wantOK, ok) + if ok { + assert.Equal(t, tc.wantVal, cluster) + } + }) + } +} + +func TestExtractClusterFromCustomDomains(t *testing.T) { + customDomains := []*domain.Domain{ + {Domain: "example.com", TargetCluster: "eu1.proxy.netbird.io"}, + {Domain: "proxy.corp.io", TargetCluster: "us1.proxy.netbird.io"}, + } + + tests := []struct { + name string + domain string + wantOK bool + wantVal string + }{ + { + name: "subdomain of custom domain matches", + domain: "app.example.com", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "bare custom domain matches", + domain: "example.com", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "deep subdomain of custom domain matches", + domain: "a.b.example.com", + wantOK: true, + wantVal: "eu1.proxy.netbird.io", + }, + { + name: "subdomain of multi-level custom domain matches", + domain: "app.proxy.corp.io", + wantOK: true, + wantVal: "us1.proxy.netbird.io", + }, + { + name: "bare multi-level custom domain matches", + domain: "proxy.corp.io", + wantOK: true, + wantVal: "us1.proxy.netbird.io", + }, + { + name: "unrelated domain does not match", + domain: "other.com", + wantOK: false, + }, + { + name: "partial suffix does not match custom domain", + domain: "fakeexample.com", + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains) + assert.Equal(t, tc.wantOK, ok) + if ok { + assert.Equal(t, tc.wantVal, cluster) + } + }) + } +} + +func TestExtractClusterFromCustomDomains_OverlappingDomains(t *testing.T) { + customDomains := []*domain.Domain{ + {Domain: "example.com", 
TargetCluster: "cluster-generic"}, + {Domain: "app.example.com", TargetCluster: "cluster-app"}, + } + + tests := []struct { + name string + domain string + wantVal string + }{ + { + name: "exact match on more specific domain", + domain: "app.example.com", + wantVal: "cluster-app", + }, + { + name: "subdomain of more specific domain", + domain: "api.app.example.com", + wantVal: "cluster-app", + }, + { + name: "subdomain of generic domain", + domain: "other.example.com", + wantVal: "cluster-generic", + }, + { + name: "bare generic domain", + domain: "example.com", + wantVal: "cluster-generic", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains) + assert.True(t, ok) + assert.Equal(t, tc.wantVal, cluster) + }) + } +} diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 813027ea2..c6c41bfe5 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -31,18 +31,15 @@ type store interface { type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) -} - -type clusterCapabilities interface { - ClusterSupportsCustomPorts(clusterAddr string) *bool + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool } type Manager struct { - store store - validator domain.Validator - proxyManager proxyManager - clusterCapabilities clusterCapabilities - permissionsManager permissions.Manager + store store + validator domain.Validator + proxyManager proxyManager + permissionsManager permissions.Manager accountManager account.Manager } @@ -56,11 +53,6 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio } } -// SetClusterCapabilities 
sets the cluster capabilities provider for domain queries. -func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) { - m.clusterCapabilities = caps -} - func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { @@ -96,9 +88,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d Type: domain.TypeFree, Validated: true, } - if m.clusterCapabilities != nil { - d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster) - } + d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster) + d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster) ret = append(ret, d) } @@ -112,9 +103,11 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d Type: domain.TypeCustom, Validated: d.Validated, } - if m.clusterCapabilities != nil && d.TargetCluster != "" { - cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster) + if d.TargetCluster != "" { + cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster) } + // Custom domains never require a subdomain by default since + // the account owns them and should be able to use the bare domain. 
ret = append(ret, cd) } @@ -302,13 +295,19 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) } -func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) { - for _, customDomain := range customDomains { - if strings.HasSuffix(domain, "."+customDomain.Domain) { - return customDomain.TargetCluster, true +func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { + bestCluster := "" + bestLen := -1 + for _, cd := range customDomains { + if serviceDomain != cd.Domain && !strings.HasSuffix(serviceDomain, "."+cd.Domain) { + continue + } + if l := len(cd.Domain); l > bestLen { + bestLen = l + bestCluster = cd.TargetCluster } } - return "", false + return bestCluster, bestLen >= 0 } // ExtractClusterFromFreeDomain extracts the cluster address from a free domain. diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 67a8e74fa..0368b84de 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,10 +11,13 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error Disconnect(ctx context.Context, proxyID string) error - Heartbeat(ctx context.Context, proxyID string) error + Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusters(ctx context.Context) ([]Cluster, error) + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx 
context.Context, clusterAddr string) *bool CleanupStale(ctx context.Context, inactivityDuration time.Duration) error } @@ -33,5 +36,4 @@ type Controller interface { RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error GetProxiesForCluster(clusterAddr string) []string - ClusterSupportsCustomPorts(clusterAddr string) *bool } diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go index acb49c45b..e5b3e9886 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/controller.go +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -72,11 +72,6 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster return nil } -// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports. -func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool { - return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr) -} - // GetProxiesForCluster returns all proxy IDs registered for a specific cluster. 
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { proxySet, ok := c.clusterProxies.Load(clusterAddr) diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index 4c0964b5c..a92fffab9 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -13,8 +13,11 @@ import ( // store defines the interface for proxy persistence operations type store interface { SaveProxy(ctx context.Context, p *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID string) error + UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error } @@ -37,9 +40,14 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { }, nil } -// Connect registers a new proxy connection in the database -func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// Connect registers a new proxy connection in the database. +// capabilities may be nil for old proxies that do not report them. 
+func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { now := time.Now() + var caps proxy.Capabilities + if capabilities != nil { + caps = *capabilities + } p := &proxy.Proxy{ ID: proxyID, ClusterAddress: clusterAddress, @@ -47,6 +55,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress LastSeen: now, ConnectedAt: &now, Status: "connected", + Capabilities: caps, } if err := m.store.SaveProxy(ctx, p); err != nil { @@ -86,11 +95,13 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error { } // Heartbeat updates the proxy's last seen timestamp -func (m Manager) Heartbeat(ctx context.Context, proxyID string) error { - if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil { +func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) return err } + + log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID) m.metrics.IncrementProxyHeartbeatCount() return nil } @@ -105,6 +116,28 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error return addresses, nil } +// GetActiveClusters returns all active proxy clusters with their connected proxy count. +func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) { + clusters, err := m.store.GetActiveProxyClusters(ctx) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err) + return nil, err + } + return clusters, nil +} + +// ClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy has reported capabilities. 
+func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr) +} + +// ClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterRequireSubdomain(ctx, clusterAddr) +} + // CleanupStale removes proxies that haven't sent heartbeat in the specified duration func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index b07a21122..97466c503 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -50,18 +50,46 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration) } -// Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// ClusterSupportsCustomPorts mocks base method. +func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. 
+func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr) +} + +// ClusterRequireSubdomain mocks base method. +func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain. +func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr) +} + +// Connect mocks base method. +func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) ret0, _ := ret[0].(error) return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) } // Disconnect mocks base method. 
@@ -93,18 +121,33 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) } -// Heartbeat mocks base method. -func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error { +// GetActiveClusters mocks base method. +func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID) + ret := m.ctrl.Call(m, "GetActiveClusters", ctx) + ret0, _ := ret[0].([]Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusters indicates an expected call of GetActiveClusters. +func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx) +} + +// Heartbeat mocks base method. +func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress) ret0, _ := ret[0].(error) return ret0 } // Heartbeat indicates an expected call of Heartbeat. -func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) } // MockController is a mock of Controller interface. 
@@ -144,20 +187,6 @@ func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) } -// ClusterSupportsCustomPorts mocks base method. -func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. -func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr) -} - // GetProxiesForCluster mocks base method. func (m *MockController) GetProxiesForCluster(clusterAddr string) []string { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 699e1ed02..4102e50fe 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -2,6 +2,17 @@ package proxy import "time" +// Capabilities describes what a proxy can handle, as reported via gRPC. +// Nil fields mean the proxy never reported this capability. +type Capabilities struct { + // SupportsCustomPorts indicates whether this proxy can bind arbitrary + // ports for TCP/UDP services. TLS uses SNI routing and is not gated. + SupportsCustomPorts *bool + // RequireSubdomain indicates whether a subdomain label is required in + // front of the cluster domain. 
+ RequireSubdomain *bool +} + // Proxy represents a reverse proxy instance type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` @@ -11,6 +22,7 @@ type Proxy struct { ConnectedAt *time.Time DisconnectedAt *time.Time Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Capabilities Capabilities `gorm:"embedded"` CreatedAt time.Time UpdatedAt time.Time } @@ -18,3 +30,9 @@ type Proxy struct { func (Proxy) TableName() string { return "proxies" } + +// Cluster represents a group of proxy nodes serving the same address. +type Cluster struct { + Address string + ConnectedProxies int +} diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index 39fd7e3ae..a49cbea35 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -4,9 +4,12 @@ package service import ( "context" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" ) type Manager interface { + GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index bdc1f3e65..cc5ccbb8e 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" ) // MockManager is a mock of 
Manager interface. @@ -107,6 +108,21 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID) } +// GetActiveClusters mocks base method. +func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID) + ret0, _ := ret[0].([]proxy.Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusters indicates an expected call of GetActiveClusters. +func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID) +} + // GetAllServices mocks base method. 
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index c53219d2e..cd81efa88 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -34,6 +34,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma accesslogsmanager.RegisterEndpoints(router, accessLogsManager) + router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") @@ -177,3 +178,27 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } + +func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + apiClusters := make([]api.ProxyCluster, 0, len(clusters)) + for _, c := range clusters { + apiClusters = append(apiClusters, api.ProxyCluster{ + Address: c.Address, + ConnectedProxies: c.ConnectedProxies, + }) + } + + util.WriteJSONObject(r.Context(), w, apiClusters) +} diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index c7a61ddcf..4a7647d90 100644 --- 
a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -75,10 +75,13 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor require.NoError(t, err) mockCtrl := proxy.NewMockController(ctrl) - mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes() mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes() + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes() + accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, @@ -92,6 +95,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor accountManager: accountMgr, permissionsManager: permissions.NewManager(testStore), proxyController: mockCtrl, + capabilities: mockCaps, clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, } mgr.exposeReaper = &exposeReaper{manager: mgr} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 65177bf5d..db393ef38 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -14,6 +14,8 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + 
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" @@ -73,22 +75,30 @@ type ClusterDeriver interface { GetClusterDomains() []string } +// CapabilityProvider queries proxy cluster capabilities from the database. +type CapabilityProvider interface { + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool +} + type Manager struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager proxyController proxy.Controller + capabilities CapabilityProvider clusterDeriver ClusterDeriver exposeReaper *exposeReaper } // NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager { +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { mgr := &Manager{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, proxyController: proxyController, + capabilities: capabilities, clusterDeriver: clusterDeriver, } mgr.exposeReaper = &exposeReaper{manager: mgr} @@ -100,6 +110,19 @@ func (m *Manager) StartExposeReaper(ctx context.Context) { m.exposeReaper.StartExposeReaper(ctx) } +// GetActiveClusters returns all active proxy clusters with their connected proxy count. 
+func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetActiveProxyClusters(ctx) +} + func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { @@ -221,6 +244,10 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err) } service.ProxyCluster = proxyCluster + + if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil { + return err + } } service.AccountID = accountID @@ -246,6 +273,20 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri return nil } +// validateSubdomainRequirement checks whether the domain can be used bare +// (without a subdomain label) on the given cluster. If the cluster reports +// require_subdomain=true and the domain equals the cluster domain, it rejects. 
+func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error { + if domain != cluster { + return nil + } + requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster) + if requireSub != nil && *requireSub { + return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain) + } + return nil +} + func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if svc.Domain != "" { @@ -279,7 +320,7 @@ func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service if !service.IsL4Protocol(svc.Mode) { return nil } - customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster) + customPorts := m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster) if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) { if svc.Source != service.SourceEphemeral { return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster) @@ -474,53 +515,65 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se var updateInfo serviceUpdateInfo err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) - if err != nil { - return err - } - - if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { - return err - } - - updateInfo.oldCluster = existingService.ProxyCluster - updateInfo.domainChanged = existingService.Domain != service.Domain - - if updateInfo.domainChanged { - if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil { - return err - } - } else { - service.ProxyCluster = existingService.ProxyCluster - } - - m.preserveExistingAuthSecrets(service, existingService) 
- if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil { - return err - } - m.preserveServiceMetadata(service, existingService) - m.preserveListenPort(service, existingService) - updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled - - if err := m.ensureL4Port(ctx, transaction, service); err != nil { - return err - } - if err := m.checkPortConflict(ctx, transaction, service); err != nil { - return err - } - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } - if err := transaction.UpdateService(ctx, service); err != nil { - return fmt.Errorf("update service: %w", err) - } - - return nil + return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo) }) return &updateInfo, err } +func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error { + existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) + if err != nil { + return err + } + + if existingService.Terminated { + return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated") + } + + if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { + return err + } + + updateInfo.oldCluster = existingService.ProxyCluster + updateInfo.domainChanged = existingService.Domain != service.Domain + + if updateInfo.domainChanged { + if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil { + return err + } + } else { + service.ProxyCluster = existingService.ProxyCluster + } + + if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil { + return err + } + + m.preserveExistingAuthSecrets(service, existingService) + if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil { + return err + } + 
m.preserveServiceMetadata(service, existingService) + m.preserveListenPort(service, existingService) + updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled + + if err := m.ensureL4Port(ctx, transaction, service); err != nil { + return err + } + if err := m.checkPortConflict(ctx, transaction, service); err != nil { + return err + } + if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { + return err + } + if err := transaction.UpdateService(ctx, service); err != nil { + return fmt.Errorf("update service: %w", err) + } + + return nil +} + func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { return err @@ -636,18 +689,12 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco for _, target := range targets { switch target.TargetType { case service.TargetTypePeer: - if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) + if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil { + return err } case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain: - if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) + if err := 
validateResourceTarget(ctx, transaction, accountID, target); err != nil { + return err } default: return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId) @@ -656,6 +703,39 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return nil } +func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) + } + return nil +} + +func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) + } + return validateResourceTargetType(target, resource) +} + +// validateResourceTargetType checks that target_type matches the actual network resource type. 
+func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error { + expected := resourcetypes.NetworkResourceType(target.TargetType) + if resource.Type != expected { + return status.Errorf(status.InvalidArgument, + "target %q has target_type %q but resource is of type %q", + target.TargetId, target.TargetType, resource.Type, + ) + } + return nil +} + func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index d23c91017..f6e532118 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/mock_server" + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -1214,3 +1215,126 @@ func TestValidateProtocolChange(t *testing.T) { }) } } + +func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + tests := []struct { + name string + targetType rpservice.TargetType + resourceType resourcetypes.NetworkResourceType + wantErr bool + }{ + {"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false}, + {"domain matches domain", 
rpservice.TargetTypeDomain, resourcetypes.Domain, false}, + {"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false}, + {"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true}, + {"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true}, + {"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true}, + {"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore.EXPECT(). + GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1"). + Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil) + + targets := []*rpservice.Target{ + {TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"}, + } + err := validateTargetReferences(ctx, mockStore, accountID, targets) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "target_type") + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateTargetReferences_PeerValid(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + mockStore.EXPECT(). + GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1"). 
+ Return(&nbpeer.Peer{}, nil) + + targets := []*rpservice.Target{ + {TargetId: "peer-1", TargetType: rpservice.TargetTypePeer}, + } + require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets)) +} + +func TestValidateSubdomainRequirement(t *testing.T) { + ptrBool := func(b bool) *bool { return &b } + + tests := []struct { + name string + domain string + cluster string + requireSubdomain *bool + wantErr bool + }{ + { + name: "subdomain present, require_subdomain true", + domain: "app.eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(true), + wantErr: false, + }, + { + name: "bare cluster domain, require_subdomain true", + domain: "eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(true), + wantErr: true, + }, + { + name: "bare cluster domain, require_subdomain false", + domain: "eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(false), + wantErr: false, + }, + { + name: "bare cluster domain, require_subdomain nil (default)", + domain: "eu1.proxy.netbird.io", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: nil, + wantErr: false, + }, + { + name: "custom domain apex is not the cluster", + domain: "example.com", + cluster: "eu1.proxy.netbird.io", + requireSubdomain: ptrBool(true), + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes() + + mgr := &Manager{capabilities: mockCaps} + err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster) + if tc.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "requires a subdomain label") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/management/internals/modules/reverseproxy/service/service.go 
b/management/internals/modules/reverseproxy/service/service.go index 6c7c80806..d956013ea 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -184,6 +184,7 @@ type Service struct { ProxyCluster string `gorm:"index"` Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"` Enabled bool + Terminated bool PassHostHeader bool RewriteRedirects bool Auth AuthConfig `gorm:"serializer:json"` @@ -256,13 +257,15 @@ func (s *Service) ToAPIResponse() *api.Service { Protocol: api.ServiceTargetProtocol(target.Protocol), TargetId: target.TargetId, TargetType: api.ServiceTargetTargetType(target.TargetType), - Enabled: target.Enabled, + Enabled: target.Enabled && !s.Terminated, } opts := targetOptionsToAPI(target.Options) if opts == nil { opts = &api.ServiceTargetOptions{} } - opts.ProxyProtocol = &target.ProxyProtocol + if target.ProxyProtocol { + opts.ProxyProtocol = &target.ProxyProtocol + } st.Options = opts apiTargets = append(apiTargets, st) } @@ -284,7 +287,8 @@ func (s *Service) ToAPIResponse() *api.Service { Name: s.Name, Domain: s.Domain, Targets: apiTargets, - Enabled: s.Enabled, + Enabled: s.Enabled && !s.Terminated, + Terminated: &s.Terminated, PassHostHeader: &s.PassHostHeader, RewriteRedirects: &s.RewriteRedirects, Auth: authConfig, @@ -790,7 +794,7 @@ func (s *Service) validateL4Target(target *Target) error { return errors.New("target_id is required for L4 services") } switch target.TargetType { - case TargetTypePeer, TargetTypeHost: + case TargetTypePeer, TargetTypeHost, TargetTypeDomain: // OK case TargetTypeSubnet: if target.Host == "" { @@ -848,7 +852,7 @@ func IsPortBasedProtocol(mode string) bool { } const ( - maxCustomHeaders = 16 + maxCustomHeaders = 16 maxHeaderKeyLen = 128 maxHeaderValueLen = 4096 ) @@ -945,7 +949,6 @@ func containsCRLF(s string) bool { } func validateHeaderAuths(headers []*HeaderAuthConfig) error { - seen := 
make(map[string]struct{}) for i, h := range headers { if h == nil || !h.Enabled { continue @@ -966,10 +969,6 @@ func validateHeaderAuths(headers []*HeaderAuthConfig) error { if canonical == "Host" { return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i) } - if _, dup := seen[canonical]; dup { - return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header) - } - seen[canonical] = struct{}{} if len(h.Value) > maxHeaderValueLen { return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen) } @@ -1128,6 +1127,7 @@ func (s *Service) Copy() *Service { ProxyCluster: s.ProxyCluster, Targets: targets, Enabled: s.Enabled, + Terminated: s.Terminated, PassHostHeader: s.PassHostHeader, RewriteRedirects: s.RewriteRedirects, Auth: authCopy, diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index 9daf729fe..ff54cb79f 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -847,6 +847,32 @@ func TestValidate_TLSSubnetValid(t *testing.T) { require.NoError(t, rp.Validate()) } +func TestValidate_L4DomainTargetValid(t *testing.T) { + modes := []struct { + mode string + port uint16 + proto string + }{ + {"tcp", 5432, "tcp"}, + {"tls", 443, "tcp"}, + {"udp", 5432, "udp"}, + } + for _, m := range modes { + t.Run(m.mode, func(t *testing.T) { + rp := &Service{ + Name: m.mode + "-domain", + Mode: m.mode, + Domain: "cluster.test", + ListenPort: m.port, + Targets: []*Target{ + {TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: m.proto, Port: m.port, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) + }) + } +} + func TestValidate_HTTPProxyProtocolRejected(t *testing.T) { rp := validProxy() rp.Targets[0].ProxyProtocol = true @@ -909,3 +935,107 @@ func 
TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) { req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"} require.NoError(t, req.Validate()) } + +func TestValidate_HeaderAuths(t *testing.T) { + t.Run("single valid header", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "X-API-Key", Value: "secret"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("multiple headers same canonical name allowed", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Authorization", Value: "Bearer token-1"}, + {Enabled: true, Header: "Authorization", Value: "Bearer token-2"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("multiple headers different case same canonical allowed", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "x-api-key", Value: "key-1"}, + {Enabled: true, Header: "X-Api-Key", Value: "key-2"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("multiple different headers allowed", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Authorization", Value: "Bearer tok"}, + {Enabled: true, Header: "X-API-Key", Value: "key"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("empty header name rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "", Value: "val"}, + }, + } + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "header name is required") + }) + + t.Run("hop-by-hop header rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Connection", Value: "val"}, + }, + } + err := rp.Validate() + 
require.Error(t, err) + assert.Contains(t, err.Error(), "hop-by-hop") + }) + + t.Run("host header rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "Host", Value: "val"}, + }, + } + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "Host header cannot be used") + }) + + t.Run("disabled entries skipped", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: false, Header: "", Value: ""}, + {Enabled: true, Header: "X-Key", Value: "val"}, + }, + } + require.NoError(t, rp.Validate()) + }) + + t.Run("value too long rejected", func(t *testing.T) { + rp := validProxy() + rp.Auth = AuthConfig{ + HeaderAuths: []*HeaderAuthConfig{ + {Enabled: true, Header: "X-Key", Value: strings.Repeat("a", maxHeaderValueLen+1)}, + }, + } + err := rp.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum length") + }) +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index a32cf6046..6064bd5b6 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -195,7 +195,7 @@ func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) ServiceManager() service.Manager { return Create(s, func() service.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager()) + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) }) } @@ -212,9 +212,6 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), 
s.AccountManager()) - s.AfterInit(func(s *BaseServer) { - m.SetClusterCapabilities(s.ServiceProxyController()) - }) return &m }) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index fd993fb40..07732cea6 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -123,7 +123,7 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil { + if err := s.proxyManager.CleanupStale(ctx, 1*time.Hour); err != nil { log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err) } } @@ -182,9 +182,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) } - // Register proxy in database - if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err) + // Register proxy in database with capabilities + var caps *proxy.Capabilities + if c := req.GetCapabilities(); c != nil { + caps = &proxy.Capabilities{ + SupportsCustomPorts: c.SupportsCustomPorts, + RequireSubdomain: c.RequireSubdomain, + } + } + if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) + s.connectedProxies.Delete(proxyID) + if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) + } + return status.Errorf(codes.Internal, "register proxy in database: %v", err) } log.WithFields(log.Fields{ @@ -215,7 +227,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req 
*proto.GetMappingUpdateRequest go s.sender(conn, errChan) // Start heartbeat goroutine - go s.heartbeat(connCtx, proxyID) + go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo) select { case err := <-errChan: @@ -226,14 +238,14 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } // heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) { +func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil { + if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) } case <-ctx.Done(): @@ -297,6 +309,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * } m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + if !proxyAcceptsMapping(conn, m) { + continue + } mappings = append(mappings, m) } return mappings, nil @@ -445,22 +460,46 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd log.Debugf("Sending service update to cluster %s", clusterAddr) for _, proxyID := range proxyIDs { - if connVal, ok := s.connectedProxies.Load(proxyID); ok { - conn := connVal.(*proxyConnection) - msg := s.perProxyMessage(updateResponse, proxyID) - if msg == nil { - continue - } - select { - case conn.sendChan <- msg: - log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) - default: - log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) - } + connVal, ok := s.connectedProxies.Load(proxyID) + if !ok { + continue + } + conn := 
connVal.(*proxyConnection) + if !proxyAcceptsMapping(conn, update) { + log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) + continue + } + msg := s.perProxyMessage(updateResponse, proxyID) + if msg == nil { + continue + } + select { + case conn.sendChan <- msg: + log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + default: + log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) } } } +// proxyAcceptsMapping returns whether the proxy should receive this mapping. +// Old proxies that never reported capabilities are skipped for non-TLS L4 +// mappings with a custom listen port, since they don't understand the +// protocol. Proxies that report capabilities (even SupportsCustomPorts=false) +// are new enough to handle the mapping. TLS uses SNI routing and works on +// any proxy. Delete operations are always sent so proxies can clean up. +func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool { + if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { + return true + } + if mapping.ListenPort == 0 || mapping.Mode == "tls" { + return true + } + // Old proxies that never reported capabilities don't understand + // custom port mappings. + return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil +} + // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. @@ -508,35 +547,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } } -// ClusterSupportsCustomPorts returns whether any connected proxy in the given -// cluster reports custom port support. 
Returns nil if no proxy has reported -// capabilities (old proxies that predate the field). -func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool { - if s.proxyController == nil { - return nil - } - - var hasCapabilities bool - for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { - connVal, ok := s.connectedProxies.Load(pid) - if !ok { - continue - } - conn := connVal.(*proxyConnection) - if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil { - continue - } - if *conn.capabilities.SupportsCustomPorts { - return ptr(true) - } - hasCapabilities = true - } - if hasCapabilities { - return ptr(false) - } - return nil -} - func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 22fe4506b..0fa9a0dc1 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/types" ) @@ -90,6 +91,10 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} +func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { + return nil, nil +} + type mockUsersManager struct { users map[string]*types.User err error diff --git a/management/internals/shared/grpc/proxy_test.go 
b/management/internals/shared/grpc/proxy_test.go index 1a4ea3330..d5aed3dee 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -53,10 +53,6 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus return nil } -func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { - return ptr(true) -} - func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { c.mu.Lock() defer c.mu.Unlock() @@ -351,14 +347,14 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { const cluster = "proxy.example.com" - // Proxy A supports custom ports. - chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) - // Proxy B does NOT support custom ports (shared cloud proxy). - chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + // Modern proxy reports capabilities. + chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + // Legacy proxy never reported capabilities (nil). + chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) ctx := context.Background() - // TLS passthrough works on all proxies regardless of custom port support. + // TLS passthrough with custom port: all proxies receive it (SNI routing). 
tlsMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-tls", @@ -371,12 +367,26 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster) - msgA := drainMapping(chA) - msgB := drainMapping(chB) - assert.NotNil(t, msgA, "proxy-a should receive TLS mapping") - assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)") + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)") - // Send an HTTP mapping: both should receive it. + // TCP mapping with custom port: only modern proxy receives it. + tcpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tcp", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tcp", + ListenPort: 5432, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping") + assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping") + + // HTTP mapping (no listen port): both receive it. 
httpMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-http", @@ -387,10 +397,16 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { s.SendServiceUpdateToCluster(ctx, httpMapping, cluster) - msgA = drainMapping(chA) - msgB = drainMapping(chB) - assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping") - assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping") + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping") + + // Proxy that reports SupportsCustomPorts=false still receives custom-port + // mappings because it understands the protocol (it's new enough). + chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping") } func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { @@ -404,7 +420,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { const cluster = "proxy.example.com" - chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + // Legacy proxy (no capabilities) still receives TLS since it uses SNI. 
+ chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) tlsMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, @@ -417,8 +434,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster) - msg := drainMapping(chShared) - assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support") + msg := drainMapping(chLegacy) + assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)") } // TestServiceModifyNotifications exercises every possible modification @@ -585,7 +602,7 @@ func TestServiceModifyNotifications(t *testing.T) { s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) - chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + chLegacy := registerFakeProxy(s, "legacy", cluster) // TLS passthrough works on all proxies regardless of custom port support s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster) @@ -604,7 +621,7 @@ func TestServiceModifyNotifications(t *testing.T) { } s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" - chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + chLegacy := registerFakeProxy(s, "legacy", cluster) mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) mapping.ListenPort = 0 // default port diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 647e8443b..2f77de86e 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ 
b/management/internals/shared/grpc/validate_session_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/store" @@ -320,6 +321,10 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { + return nil, nil +} + type testValidateSessionProxyManager struct{} func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error { @@ -338,6 +343,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co return nil, nil } +func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) { + return nil, nil +} + func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } diff --git a/management/server/account_test.go b/management/server/account_test.go index fdec43617..548cf31d4 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3138,7 +3138,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU if err != nil { return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil)) + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil)) return manager, updateManager, nil } diff --git 
a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 3bed54e80..922bf4352 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/store" @@ -433,6 +434,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri func (m *testServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { + return nil, nil +} + func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { t.Helper() diff --git a/management/server/http/testing/integration/accounts_handler_integration_test.go b/management/server/http/testing/integration/accounts_handler_integration_test.go new file mode 100644 index 000000000..511730ee5 --- /dev/null +++ b/management/server/http/testing/integration/accounts_handler_integration_test.go @@ -0,0 +1,238 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/management/server/types" + 
"github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Accounts_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all accounts", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/accounts", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Account{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + account := got[0] + assert.Equal(t, "test.com", account.Domain) + assert.Equal(t, "private", account.DomainCategory) + assert.Equal(t, true, account.Settings.PeerLoginExpirationEnabled) + assert.Equal(t, 86400, account.Settings.PeerLoginExpiration) + assert.Equal(t, false, account.Settings.RegularUsersViewBlocked) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Accounts_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", 
testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + trueVal := true + falseVal := false + + tt := []struct { + name string + expectedStatus int + requestBody *api.AccountRequest + verifyResponse func(t *testing.T, account *api.Account) + verifyDB func(t *testing.T, account *types.Account) + }{ + { + name: "Disable peer login expiration", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: false, + PeerLoginExpiration: 86400, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.Equal(t, false, account.Settings.PeerLoginExpirationEnabled) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, false, dbAccount.Settings.PeerLoginExpirationEnabled) + }, + }, + { + name: "Update peer login expiration to 48h", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 172800, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.Equal(t, 172800, account.Settings.PeerLoginExpiration) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, 172800*time.Second, dbAccount.Settings.PeerLoginExpiration) + }, + }, + { + name: "Enable regular users view blocked", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 86400, + RegularUsersViewBlocked: true, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account 
*api.Account) { + t.Helper() + assert.Equal(t, true, account.Settings.RegularUsersViewBlocked) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, true, dbAccount.Settings.RegularUsersViewBlocked) + }, + }, + { + name: "Enable groups propagation", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 86400, + GroupsPropagationEnabled: &trueVal, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.NotNil(t, account.Settings.GroupsPropagationEnabled) + assert.Equal(t, true, *account.Settings.GroupsPropagationEnabled) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, true, dbAccount.Settings.GroupsPropagationEnabled) + }, + }, + { + name: "Enable JWT groups", + requestBody: &api.AccountRequest{ + Settings: api.AccountSettings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: 86400, + GroupsPropagationEnabled: &falseVal, + JwtGroupsEnabled: &trueVal, + JwtGroupsClaimName: stringPointer("groups"), + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, account *api.Account) { + t.Helper() + assert.NotNil(t, account.Settings.JwtGroupsEnabled) + assert.Equal(t, true, *account.Settings.JwtGroupsEnabled) + assert.NotNil(t, account.Settings.JwtGroupsClaimName) + assert.Equal(t, "groups", *account.Settings.JwtGroupsClaimName) + }, + verifyDB: func(t *testing.T, dbAccount *types.Account) { + t.Helper() + assert.Equal(t, true, dbAccount.Settings.JWTGroupsEnabled) + assert.Equal(t, "groups", dbAccount.Settings.JWTGroupsClaimName) + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil 
{ + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/accounts/{accountId}", "{accountId}", testing_tools.TestAccountId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + got := &api.Account{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, testing_tools.TestAccountId, got.Id) + assert.Equal(t, "test.com", got.Domain) + tc.verifyResponse(t, got) + + db := testing_tools.GetDB(t, am.GetStore()) + dbAccount := testing_tools.VerifyAccountSettings(t, db) + tc.verifyDB(t, dbAccount) + }) + } + } +} + +func stringPointer(s string) *string { + return &s +} diff --git a/management/server/http/testing/integration/dns_handler_integration_test.go b/management/server/http/testing/integration/dns_handler_integration_test.go new file mode 100644 index 000000000..7ada5e462 --- /dev/null +++ b/management/server/http/testing/integration/dns_handler_integration_test.go @@ -0,0 +1,554 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Nameservers_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", 
testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all nameservers", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/nameservers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.NameserverGroup{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testNSGroup", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Nameservers_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + nsGroupId string + expectedStatus int + expectGroup bool + }{ + { + name: "Get existing nameserver group", + nsGroupId: "testNSGroupId", + expectedStatus: http.StatusOK, + expectGroup: true, + }, + { + 
name: "Get non-existing nameserver group", + nsGroupId: "nonExistingNSGroupId", + expectedStatus: http.StatusNotFound, + expectGroup: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectGroup { + got := &api.NameserverGroup{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "testNSGroupId", got.Id) + assert.Equal(t, "testNSGroup", got.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Nameservers_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.PostApiDnsNameserversJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup) + }{ + { + name: "Create nameserver group with single NS", + requestBody: 
&api.PostApiDnsNameserversJSONRequestBody{ + Name: "newNSGroup", + Description: "a new nameserver group", + Nameservers: []api.Nameserver{ + {Ip: "8.8.8.8", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + Primary: false, + Domains: []string{"test.com"}, + Enabled: true, + SearchDomainsEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) { + t.Helper() + assert.NotEmpty(t, nsGroup.Id) + assert.Equal(t, "newNSGroup", nsGroup.Name) + assert.Equal(t, 1, len(nsGroup.Nameservers)) + assert.Equal(t, false, nsGroup.Primary) + }, + }, + { + name: "Create primary nameserver group", + requestBody: &api.PostApiDnsNameserversJSONRequestBody{ + Name: "primaryNS", + Description: "primary nameserver", + Nameservers: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + Primary: true, + Domains: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) { + t.Helper() + assert.Equal(t, true, nsGroup.Primary) + }, + }, + { + name: "Create nameserver group with empty groups", + requestBody: &api.PostApiDnsNameserversJSONRequestBody{ + Name: "emptyGroupsNS", + Description: "no groups", + Nameservers: []api.Nameserver{ + {Ip: "8.8.8.8", NsType: "udp", Port: 53}, + }, + Groups: []string{}, + Primary: true, + Domains: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/nameservers", user.userId) + recorder := 
httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NameserverGroup{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the created NS group directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbNS := testing_tools.VerifyNSGroupInDB(t, db, got.Id) + assert.Equal(t, got.Name, dbNS.Name) + assert.Equal(t, got.Primary, dbNS.Primary) + assert.Equal(t, len(got.Nameservers), len(dbNS.NameServers)) + assert.Equal(t, got.Enabled, dbNS.Enabled) + assert.Equal(t, got.SearchDomainsEnabled, dbNS.SearchDomainsEnabled) + } + }) + } + } +} + +func Test_Nameservers_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + nsGroupId string + requestBody *api.PutApiDnsNameserversNsgroupIdJSONRequestBody + expectedStatus int + verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup) + }{ + { + name: "Update nameserver group name", + nsGroupId: "testNSGroupId", + requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{ + Name: "updatedNSGroup", + Description: "updated description", + Nameservers: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + 
Primary: false, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) { + t.Helper() + assert.Equal(t, "updatedNSGroup", nsGroup.Name) + assert.Equal(t, "updated description", nsGroup.Description) + }, + }, + { + name: "Update non-existing nameserver group", + nsGroupId: "nonExistingNSGroupId", + requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{ + Name: "whatever", + Nameservers: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + Groups: []string{testing_tools.TestGroupId}, + Primary: true, + Domains: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NameserverGroup{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify the updated NS group directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbNS := testing_tools.VerifyNSGroupInDB(t, db, tc.nsGroupId) + assert.Equal(t, "updatedNSGroup", dbNS.Name) + assert.Equal(t, "updated description", dbNS.Description) + assert.Equal(t, false, dbNS.Primary) + 
assert.Equal(t, true, dbNS.Enabled) + assert.Equal(t, 1, len(dbNS.NameServers)) + assert.Equal(t, false, dbNS.SearchDomainsEnabled) + } + }) + } + } +} + +func Test_Nameservers_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + nsGroupId string + expectedStatus int + }{ + { + name: "Delete existing nameserver group", + nsGroupId: "testNSGroupId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing nameserver group", + nsGroupId: "nonExistingNSGroupId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify deletion in DB for successful deletes by privileged users + if tc.expectedStatus == http.StatusOK && user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyNSGroupNotInDB(t, db, tc.nsGroupId) + } + }) + } + } +} + +func Test_DnsSettings_Get(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + 
{"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get DNS settings", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/settings", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := &api.DNSSettings{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.NotNil(t, got.DisabledManagementGroups) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_DnsSettings_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.PutApiDnsSettingsJSONRequestBody + expectedStatus 
int + verifyResponse func(t *testing.T, settings *api.DNSSettings) + expectedDBDisabledMgmtLen int + expectedDBDisabledMgmtItem string + }{ + { + name: "Update disabled management groups", + requestBody: &api.PutApiDnsSettingsJSONRequestBody{ + DisabledManagementGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, settings *api.DNSSettings) { + t.Helper() + assert.Equal(t, 1, len(settings.DisabledManagementGroups)) + assert.Equal(t, testing_tools.TestGroupId, settings.DisabledManagementGroups[0]) + }, + expectedDBDisabledMgmtLen: 1, + expectedDBDisabledMgmtItem: testing_tools.TestGroupId, + }, + { + name: "Update with empty disabled management groups", + requestBody: &api.PutApiDnsSettingsJSONRequestBody{ + DisabledManagementGroups: []string{}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, settings *api.DNSSettings) { + t.Helper() + assert.Equal(t, 0, len(settings.DisabledManagementGroups)) + }, + expectedDBDisabledMgmtLen: 0, + }, + { + name: "Update with non-existing group", + requestBody: &api.PutApiDnsSettingsJSONRequestBody{ + DisabledManagementGroups: []string{"nonExistingGroupId"}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/dns/settings", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.DNSSettings{} + if err := 
json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify DNS settings directly in the DB + db := testing_tools.GetDB(t, am.GetStore()) + dbAccount := testing_tools.VerifyAccountSettings(t, db) + assert.Equal(t, tc.expectedDBDisabledMgmtLen, len(dbAccount.DNSSettings.DisabledManagementGroups)) + if tc.expectedDBDisabledMgmtItem != "" { + assert.Contains(t, dbAccount.DNSSettings.DisabledManagementGroups, tc.expectedDBDisabledMgmtItem) + } + } + }) + } + } +} diff --git a/management/server/http/testing/integration/events_handler_integration_test.go b/management/server/http/testing/integration/events_handler_integration_test.go new file mode 100644 index 000000000..6611b60ee --- /dev/null +++ b/management/server/http/testing/integration/events_handler_integration_test.go @@ -0,0 +1,105 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Events_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all events", func(t *testing.T) { + apiHandler, _, _ := 
channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false) + + // First, perform a mutation to generate an event (create a group as admin) + groupBody, err := json.Marshal(&api.GroupRequest{Name: "eventTestGroup"}) + if err != nil { + t.Fatalf("Failed to marshal group request: %v", err) + } + createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId) + createRecorder := httptest.NewRecorder() + apiHandler.ServeHTTP(createRecorder, createReq) + assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event") + + // Now query events + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Event{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group") + + // Verify the group creation event exists + found := false + for _, event := range got { + if event.ActivityCode == "group.add" { + found = true + assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId) + assert.Equal(t, "Group created", event.Activity) + break + } + } + assert.True(t, found, "Expected to find a group.add event") + }) + } +} + +func Test_Events_GetAll_Empty(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + if !expectResponse 
{ + return + } + + got := []api.Event{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 0, len(got), "Expected empty events list when no mutations have been performed") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } +} diff --git a/management/server/http/testing/integration/groups_handler_integration_test.go b/management/server/http/testing/integration/groups_handler_integration_test.go new file mode 100644 index 000000000..edb43f3f3 --- /dev/null +++ b/management/server/http/testing/integration/groups_handler_integration_test.go @@ -0,0 +1,382 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Groups_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all groups", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/groups", 
user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Group{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 2) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Groups_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + groupId string + expectedStatus int + expectGroup bool + }{ + { + name: "Get existing group", + groupId: testing_tools.TestGroupId, + expectedStatus: http.StatusOK, + expectGroup: true, + }, + { + name: "Get non-existing group", + groupId: "nonExistingGroupId", + expectedStatus: http.StatusNotFound, + expectGroup: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := 
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectGroup { + got := &api.Group{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.groupId, got.Id) + assert.Equal(t, "testGroupName", got.Name) + assert.Equal(t, 1, got.PeersCount) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Groups_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.GroupRequest + expectedStatus int + verifyResponse func(t *testing.T, group *api.Group) + }{ + { + name: "Create group with valid name", + requestBody: &api.GroupRequest{ + Name: "brandNewGroup", + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.NotEmpty(t, group.Id) + assert.Equal(t, "brandNewGroup", group.Name) + assert.Equal(t, 0, group.PeersCount) + }, + }, + { + name: "Create group with peers", + requestBody: &api.GroupRequest{ + Name: "groupWithPeers", + Peers: &[]string{testing_tools.TestPeerId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.NotEmpty(t, group.Id) + assert.Equal(t, "groupWithPeers", group.Name) + assert.Equal(t, 1, group.PeersCount) + }, + }, + 
{ + name: "Create group with empty name", + requestBody: &api.GroupRequest{ + Name: "", + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/groups", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Group{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify group exists in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id) + assert.Equal(t, tc.requestBody.Name, dbGroup.Name) + } + }) + } + } +} + +func Test_Groups_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + groupId string + requestBody *api.GroupRequest + expectedStatus int + verifyResponse func(t *testing.T, group *api.Group) + }{ + { + name: "Update group name", + 
groupId: testing_tools.TestGroupId, + requestBody: &api.GroupRequest{ + Name: "updatedGroupName", + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.Equal(t, testing_tools.TestGroupId, group.Id) + assert.Equal(t, "updatedGroupName", group.Name) + }, + }, + { + name: "Update group peers", + groupId: testing_tools.TestGroupId, + requestBody: &api.GroupRequest{ + Name: "testGroupName", + Peers: &[]string{}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, group *api.Group) { + t.Helper() + assert.Equal(t, 0, group.PeersCount) + }, + }, + { + name: "Update with empty name", + groupId: testing_tools.TestGroupId, + requestBody: &api.GroupRequest{ + Name: "", + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Update non-existing group", + groupId: "nonExistingGroupId", + requestBody: &api.GroupRequest{ + Name: "someName", + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Group{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated group in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbGroup := 
testing_tools.VerifyGroupInDB(t, db, tc.groupId) + assert.Equal(t, tc.requestBody.Name, dbGroup.Name) + } + }) + } + } +} + +func Test_Groups_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + groupId string + expectedStatus int + }{ + { + name: "Delete existing group not in use", + groupId: testing_tools.NewGroupId, + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing group", + groupId: "nonExistingGroupId", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyGroupNotInDB(t, db, tc.groupId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/networks_handler_integration_test.go b/management/server/http/testing/integration/networks_handler_integration_test.go new file mode 100644 index 000000000..4cb6b268b --- /dev/null +++ 
b/management/server/http/testing/integration/networks_handler_integration_test.go @@ -0,0 +1,1434 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Networks_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all networks", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.Network{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testNetworkId", got[0].Id) + assert.Equal(t, "testNetwork", got[0].Name) + assert.Equal(t, "test network description", 
*got[0].Description) + assert.GreaterOrEqual(t, len(got[0].Routers), 1) + assert.GreaterOrEqual(t, len(got[0].Resources), 1) + assert.GreaterOrEqual(t, got[0].RoutingPeersCount, 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Networks_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + expectedStatus int + expectNetwork bool + }{ + { + name: "Get existing network", + networkId: "testNetworkId", + expectedStatus: http.StatusOK, + expectNetwork: true, + }, + { + name: "Get non-existing network", + networkId: "nonExistingNetworkId", + expectedStatus: http.StatusNotFound, + expectNetwork: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/networks/{networkId}", "{networkId}", tc.networkId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectNetwork { + got := &api.Network{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent 
content is not in correct json format; %v", err) + } + assert.Equal(t, tc.networkId, got.Id) + assert.Equal(t, "testNetwork", got.Name) + assert.Equal(t, "test network description", *got.Description) + assert.GreaterOrEqual(t, len(got.Routers), 1) + assert.GreaterOrEqual(t, len(got.Resources), 1) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Networks_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + desc := "new network description" + + tt := []struct { + name string + requestBody *api.NetworkRequest + expectedStatus int + verifyResponse func(t *testing.T, network *api.Network) + }{ + { + name: "Create network with name and description", + requestBody: &api.NetworkRequest{ + Name: "newNetwork", + Description: &desc, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, network *api.Network) { + t.Helper() + assert.NotEmpty(t, network.Id) + assert.Equal(t, "newNetwork", network.Name) + assert.Equal(t, "new network description", *network.Description) + assert.Empty(t, network.Routers) + assert.Empty(t, network.Resources) + assert.Equal(t, 0, network.RoutingPeersCount) + }, + }, + { + name: "Create network with name only", + requestBody: &api.NetworkRequest{ + Name: "simpleNetwork", + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, network *api.Network) { + t.Helper() + assert.NotEmpty(t, network.Id) + 
assert.Equal(t, "simpleNetwork", network.Name) + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/networks", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Network{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_Networks_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + updatedDesc := "updated description" + + tt := []struct { + name string + networkId string + requestBody *api.NetworkRequest + expectedStatus int + verifyResponse func(t *testing.T, network *api.Network) + }{ + { + name: "Update network name", + networkId: "testNetworkId", + requestBody: &api.NetworkRequest{ + Name: "updatedNetwork", + Description: &updatedDesc, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, network *api.Network) { + t.Helper() + assert.Equal(t, "testNetworkId", 
network.Id) + assert.Equal(t, "updatedNetwork", network.Name) + assert.Equal(t, "updated description", *network.Description) + }, + }, + { + name: "Update non-existing network", + networkId: "nonExistingNetworkId", + requestBody: &api.NetworkRequest{ + Name: "whatever", + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/networks/{networkId}", "{networkId}", tc.networkId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Network{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_Networks_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + expectedStatus int + }{ + { + name: "Delete existing network", + networkId: "testNetworkId", + expectedStatus: http.StatusOK, + }, + { + name: 
"Delete non-existing network", + networkId: "nonExistingNetworkId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/networks/{networkId}", "{networkId}", tc.networkId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_Networks_Delete_Cascades(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + // Delete the network + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/networks/testNetworkId", testing_tools.TestAdminId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + + // Verify network is gone + req = testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId", testing_tools.TestAdminId) + recorder = httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + testing_tools.ReadResponse(t, recorder, http.StatusNotFound, true) + + // Verify routers in that network are gone + req = testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId/routers", testing_tools.TestAdminId) + recorder = httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + content, _ := testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + var routers []*api.NetworkRouter + require.NoError(t, json.Unmarshal(content, &routers)) + assert.Empty(t, routers) + + // Verify resources in that network are gone + req = testing_tools.BuildRequest(t, []byte{}, http.MethodGet, 
"/api/networks/testNetworkId/resources", testing_tools.TestAdminId) + recorder = httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + content, _ = testing_tools.ReadResponse(t, recorder, http.StatusOK, true) + var resources []*api.NetworkResource + require.NoError(t, json.Unmarshal(content, &resources)) + assert.Empty(t, resources) +} + +func Test_NetworkResources_GetAllInNetwork(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all resources in network", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId/resources", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkResource{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testResourceId", got[0].Id) + assert.Equal(t, "testResource", got[0].Name) + assert.Equal(t, api.NetworkResourceType("host"), got[0].Type) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func 
Test_NetworkResources_GetAllInAccount(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all resources in account", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/resources", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkResource{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_NetworkResources_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + 
{"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + resourceId string + expectedStatus int + expectResource bool + }{ + { + name: "Get existing resource", + networkId: "testNetworkId", + resourceId: "testResourceId", + expectedStatus: http.StatusOK, + expectResource: true, + }, + { + name: "Get non-existing resource", + networkId: "testNetworkId", + resourceId: "nonExistingResourceId", + expectedStatus: http.StatusNotFound, + expectResource: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + path := fmt.Sprintf("/api/networks/%s/resources/%s", tc.networkId, tc.resourceId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectResource { + got := &api.NetworkResource{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.resourceId, got.Id) + assert.Equal(t, "testResource", got.Name) + assert.Equal(t, api.NetworkResourceType("host"), got.Type) + assert.Equal(t, "3.3.3.3/32", got.Address) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_NetworkResources_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, 
false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + desc := "new resource" + + tt := []struct { + name string + networkId string + requestBody *api.NetworkResourceRequest + expectedStatus int + verifyResponse func(t *testing.T, resource *api.NetworkResource) + }{ + { + name: "Create host resource with IP", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "hostResource", + Description: &desc, + Address: "1.1.1.1", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.NotEmpty(t, resource.Id) + assert.Equal(t, "hostResource", resource.Name) + assert.Equal(t, api.NetworkResourceType("host"), resource.Type) + assert.Equal(t, "1.1.1.1/32", resource.Address) + assert.True(t, resource.Enabled) + }, + }, + { + name: "Create host resource with CIDR /32", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "hostCIDR", + Address: "10.0.0.1/32", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("host"), resource.Type) + assert.Equal(t, "10.0.0.1/32", resource.Address) + }, + }, + { + name: "Create subnet resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "subnetResource", + Address: "192.168.0.0/24", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("subnet"), resource.Type) + assert.Equal(t, 
"192.168.0.0/24", resource.Address) + }, + }, + { + name: "Create domain resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "domainResource", + Address: "example.com", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("domain"), resource.Type) + assert.Equal(t, "example.com", resource.Address) + }, + }, + { + name: "Create wildcard domain resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "wildcardDomain", + Address: "*.example.com", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("domain"), resource.Type) + assert.Equal(t, "*.example.com", resource.Address) + }, + }, + { + name: "Create disabled resource", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "disabledResource", + Address: "5.5.5.5", + Groups: []string{testing_tools.TestGroupId}, + Enabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.False(t, resource.Enabled) + }, + }, + { + name: "Create resource with invalid address", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "invalidResource", + Address: "not-a-valid-address!!!", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Create resource with empty groups", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "noGroupsResource", + Address: "7.7.7.7", + Groups: []string{}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t 
*testing.T, resource *api.NetworkResource) { + t.Helper() + assert.NotEmpty(t, resource.Id) + }, + }, + { + name: "Create resource with duplicate name", + networkId: "testNetworkId", + requestBody: &api.NetworkResourceRequest{ + Name: "testResource", + Address: "8.8.8.8", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/resources", tc.networkId) + req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkResource{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkResources_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + updatedDesc := "updated resource" + + tt := []struct { + name string + networkId string + resourceId string + 
requestBody *api.NetworkResourceRequest + expectedStatus int + verifyResponse func(t *testing.T, resource *api.NetworkResource) + }{ + { + name: "Update resource name and address", + networkId: "testNetworkId", + resourceId: "testResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: "updatedResource", + Description: &updatedDesc, + Address: "4.4.4.4", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, "testResourceId", resource.Id) + assert.Equal(t, "updatedResource", resource.Name) + assert.Equal(t, "updated resource", *resource.Description) + assert.Equal(t, "4.4.4.4/32", resource.Address) + }, + }, + { + name: "Update resource to subnet type", + networkId: "testNetworkId", + resourceId: "testResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: "testResource", + Address: "10.0.0.0/16", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("subnet"), resource.Type) + assert.Equal(t, "10.0.0.0/16", resource.Address) + }, + }, + { + name: "Update resource to domain type", + networkId: "testNetworkId", + resourceId: "testResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: "testResource", + Address: "myservice.example.com", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, resource *api.NetworkResource) { + t.Helper() + assert.Equal(t, api.NetworkResourceType("domain"), resource.Type) + assert.Equal(t, "myservice.example.com", resource.Address) + }, + }, + { + name: "Update non-existing resource", + networkId: "testNetworkId", + resourceId: "nonExistingResourceId", + requestBody: &api.NetworkResourceRequest{ + Name: 
"whatever", + Address: "1.2.3.4", + Groups: []string{testing_tools.TestGroupId}, + Enabled: true, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/resources/%s", tc.networkId, tc.resourceId) + req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkResource{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkResources_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + resourceId string + expectedStatus int + }{ + { + name: "Delete existing resource", + networkId: "testNetworkId", + resourceId: "testResourceId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing resource", + networkId: "testNetworkId", + resourceId: 
"nonExistingResourceId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + path := fmt.Sprintf("/api/networks/%s/resources/%s", tc.networkId, tc.resourceId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} + +func Test_NetworkRouters_GetAllInNetwork(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all routers in network", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/testNetworkId/routers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkRouter{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, 
"testRouterId", got[0].Id) + assert.Equal(t, "testPeerId", *got[0].Peer) + assert.True(t, got[0].Masquerade) + assert.Equal(t, 100, got[0].Metric) + assert.True(t, got[0].Enabled) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_NetworkRouters_GetAllInAccount(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all routers in account", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/networks/routers", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []*api.NetworkRouter{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_NetworkRouters_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", 
testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + routerId string + expectedStatus int + expectRouter bool + }{ + { + name: "Get existing router", + networkId: "testNetworkId", + routerId: "testRouterId", + expectedStatus: http.StatusOK, + expectRouter: true, + }, + { + name: "Get non-existing router", + networkId: "testNetworkId", + routerId: "nonExistingRouterId", + expectedStatus: http.StatusNotFound, + expectRouter: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, true) + + path := fmt.Sprintf("/api/networks/%s/routers/%s", tc.networkId, tc.routerId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectRouter { + got := &api.NetworkRouter{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.routerId, got.Id) + assert.Equal(t, "testPeerId", *got.Peer) + assert.True(t, got.Masquerade) + assert.Equal(t, 100, got.Metric) + assert.True(t, got.Enabled) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_NetworkRouters_Create(t *testing.T) { + 
users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + peerID := "testPeerId" + peerGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + networkId string + requestBody *api.NetworkRouterRequest + expectedStatus int + verifyResponse func(t *testing.T, router *api.NetworkRouter) + }{ + { + name: "Create router with peer", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 200, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.NotEmpty(t, router.Id) + assert.Equal(t, peerID, *router.Peer) + assert.True(t, router.Masquerade) + assert.Equal(t, 200, router.Metric) + assert.True(t, router.Enabled) + }, + }, + { + name: "Create router with peer groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + PeerGroups: &peerGroups, + Masquerade: false, + Metric: 300, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.NotEmpty(t, router.Id) + assert.NotNil(t, router.PeerGroups) + assert.Equal(t, 1, len(*router.PeerGroups)) + assert.False(t, router.Masquerade) + assert.Equal(t, 300, router.Metric) + assert.True(t, router.Enabled) // always true on creation + }, + }, + { + name: "Create router with both peer and peer_groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: 
&peerID, + PeerGroups: &peerGroups, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.NotEmpty(t, router.Id) + assert.Equal(t, peerID, *router.Peer) + assert.Equal(t, 1, len(*router.PeerGroups)) + }, + }, + { + name: "Create router in non-existing network", + networkId: "nonExistingNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusNotFound, + }, + { + name: "Create router enabled is always true", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: false, + Metric: 50, + Enabled: false, // handler sets to true + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.True(t, router.Enabled) // always true on creation + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/routers", tc.networkId) + req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkRouter{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkRouters_Update(t *testing.T) { + users := []struct { + name string + userId string + 
expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + peerID := "testPeerId" + peerGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + networkId string + routerId string + requestBody *api.NetworkRouterRequest + expectedStatus int + verifyResponse func(t *testing.T, router *api.NetworkRouter) + }{ + { + name: "Update router metric and masquerade", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: false, + Metric: 500, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.Equal(t, "testRouterId", router.Id) + assert.False(t, router.Masquerade) + assert.Equal(t, 500, router.Metric) + }, + }, + { + name: "Update router to use peer groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + PeerGroups: &peerGroups, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.NotNil(t, router.PeerGroups) + assert.Equal(t, 1, len(*router.PeerGroups)) + }, + }, + { + name: "Update router disabled", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 100, + Enabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + 
assert.False(t, router.Enabled) + }, + }, + { + name: "Update non-existing router creates it", + networkId: "testNetworkId", + routerId: "nonExistingRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.Equal(t, "nonExistingRouterId", router.Id) + }, + }, + { + name: "Update router with both peer and peer_groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Peer: &peerID, + PeerGroups: &peerGroups, + Masquerade: true, + Metric: 100, + Enabled: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, router *api.NetworkRouter) { + t.Helper() + assert.Equal(t, "testRouterId", router.Id) + assert.Equal(t, peerID, *router.Peer) + assert.Equal(t, 1, len(*router.PeerGroups)) + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + path := fmt.Sprintf("/api/networks/%s/routers/%s", tc.networkId, tc.routerId) + req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.NetworkRouter{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + }) + } + } +} + +func Test_NetworkRouters_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool 
+ }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + networkId string + routerId string + expectedStatus int + }{ + { + name: "Delete existing router", + networkId: "testNetworkId", + routerId: "testRouterId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing router", + networkId: "testNetworkId", + routerId: "nonExistingRouterId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/networks.sql", nil, false) + + path := fmt.Sprintf("/api/networks/%s/routers/%s", tc.networkId, tc.routerId) + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + }) + } + } +} diff --git a/management/server/http/testing/integration/peers_handler_integration_test.go b/management/server/http/testing/integration/peers_handler_integration_test.go new file mode 100644 index 000000000..17a9e94a6 --- /dev/null +++ b/management/server/http/testing/integration/peers_handler_integration_test.go @@ -0,0 +1,605 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + 
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +const ( + testPeerId2 = "testPeerId2" +) + +func Test_Peers_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: true, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + for _, user := range users { + t.Run(user.name+" - Get all peers", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/peers", user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + var got []api.PeerBatch + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 2, "Expected at least 2 peers") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout 
waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Peers_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: true, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestType string + requestPath string + requestId string + verifyResponse func(t *testing.T, peer *api.Peer) + }{ + { + name: "Get existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, "test-peer-1", peer.Name) + assert.Equal(t, "test-host-1", peer.Hostname) + assert.Equal(t, "Debian GNU/Linux ", peer.Os) + assert.Equal(t, "0.12.0", peer.Version) + assert.Equal(t, false, peer.SshEnabled) + assert.Equal(t, true, peer.LoginExpirationEnabled) + }, + }, + { + name: "Get second existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}", + requestId: testPeerId2, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testPeerId2, peer.Id) + 
assert.Equal(t, "test-peer-2", peer.Name) + assert.Equal(t, "test-host-2", peer.Hostname) + assert.Equal(t, "Ubuntu ", peer.Os) + assert.Equal(t, true, peer.SshEnabled) + assert.Equal(t, false, peer.LoginExpirationEnabled) + assert.Equal(t, true, peer.Connected) + }, + }, + { + name: "Get non-existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}", + requestId: "nonExistingPeerId", + expectedStatus: http.StatusNotFound, + verifyResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Peer{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Peers_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: 
testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestBody *api.PeerRequest + requestType string + requestPath string + requestId string + verifyResponse func(t *testing.T, peer *api.Peer) + }{ + { + name: "Update peer name", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + requestBody: &api.PeerRequest{ + Name: "updated-peer-name", + SshEnabled: false, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, "updated-peer-name", peer.Name) + assert.Equal(t, false, peer.SshEnabled) + assert.Equal(t, true, peer.LoginExpirationEnabled) + }, + }, + { + name: "Enable SSH on peer", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + requestBody: &api.PeerRequest{ + Name: "test-peer-1", + SshEnabled: true, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, "test-peer-1", peer.Name) + assert.Equal(t, true, peer.SshEnabled) + assert.Equal(t, true, peer.LoginExpirationEnabled) + }, + }, + { + name: "Disable login expiration on peer", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: testing_tools.TestPeerId, + requestBody: &api.PeerRequest{ + Name: "test-peer-1", + SshEnabled: false, + 
LoginExpirationEnabled: false, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, peer *api.Peer) { + t.Helper() + assert.Equal(t, testing_tools.TestPeerId, peer.Id) + assert.Equal(t, false, peer.LoginExpirationEnabled) + }, + }, + { + name: "Update non-existing peer", + requestType: http.MethodPut, + requestPath: "/api/peers/{peerId}", + requestId: "nonExistingPeerId", + requestBody: &api.PeerRequest{ + Name: "updated-name", + SshEnabled: false, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: false, + }, + expectedStatus: http.StatusNotFound, + verifyResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Peer{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated peer in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbPeer := testing_tools.VerifyPeerInDB(t, db, tc.requestId) + assert.Equal(t, tc.requestBody.Name, dbPeer.Name) + assert.Equal(t, tc.requestBody.SshEnabled, dbPeer.SSHEnabled) + assert.Equal(t, tc.requestBody.LoginExpirationEnabled, dbPeer.LoginExpirationEnabled) + assert.Equal(t, tc.requestBody.InactivityExpirationEnabled, 
dbPeer.InactivityExpirationEnabled) + } + }) + } + } +} + +func Test_Peers_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestType string + requestPath string + requestId string + }{ + { + name: "Delete existing peer", + requestType: http.MethodDelete, + requestPath: "/api/peers/{peerId}", + requestId: testPeerId2, + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing peer", + requestType: http.MethodDelete, + requestPath: "/api/peers/{peerId}", + requestId: "nonExistingPeerId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + 
if !expectResponse { + return + } + + // Verify peer is actually deleted in DB + if tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPeerNotInDB(t, db, tc.requestId) + } + }) + } + } +} + +func Test_Peers_GetAccessiblePeers(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + requestType string + requestPath string + requestId string + }{ + { + name: "Get accessible peers for existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}/accessible-peers", + requestId: testing_tools.TestPeerId, + expectedStatus: http.StatusOK, + }, + { + name: "Get accessible peers for non-existing peer", + requestType: http.MethodGet, + requestPath: "/api/peers/{peerId}/accessible-peers", + requestId: "nonExistingPeerId", + expectedStatus: http.StatusOK, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, 
tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId) + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectedStatus == http.StatusOK { + var got []api.AccessiblePeer + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + // The accessible peers list should be a valid array (may be empty if no policies connect peers) + assert.NotNil(t, got, "Expected accessible peers to be a valid array") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} diff --git a/management/server/http/testing/integration/policies_handler_integration_test.go b/management/server/http/testing/integration/policies_handler_integration_test.go new file mode 100644 index 000000000..6f3624fb5 --- /dev/null +++ b/management/server/http/testing/integration/policies_handler_integration_test.go @@ -0,0 +1,488 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Policies_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, 
true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all policies", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/policies", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Policy{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "testPolicy", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Policies_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + policyId string + expectedStatus int + expectPolicy bool + }{ + { + name: "Get existing policy", + policyId: "testPolicyId", + expectedStatus: http.StatusOK, + expectPolicy: true, + }, + { + name: "Get non-existing policy", + policyId: "nonExistingPolicyId", + expectedStatus: http.StatusNotFound, + expectPolicy: false, + 
}, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectPolicy { + got := &api.Policy{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.NotNil(t, got.Id) + assert.Equal(t, tc.policyId, *got.Id) + assert.Equal(t, "testPolicy", got.Name) + assert.Equal(t, true, got.Enabled) + assert.GreaterOrEqual(t, len(got.Rules), 1) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Policies_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + srcGroups := []string{testing_tools.TestGroupId} + dstGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + requestBody *api.PolicyCreate + expectedStatus int + verifyResponse func(t *testing.T, policy *api.Policy) + }{ + { + name: "Create policy with accept rule", + 
requestBody: &api.PolicyCreate{ + Name: "newPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "allowAll", + Enabled: true, + Action: "accept", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.NotNil(t, policy.Id) + assert.Equal(t, "newPolicy", policy.Name) + assert.Equal(t, true, policy.Enabled) + assert.Equal(t, 1, len(policy.Rules)) + assert.Equal(t, "allowAll", policy.Rules[0].Name) + }, + }, + { + name: "Create policy with drop rule", + requestBody: &api.PolicyCreate{ + Name: "dropPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "dropAll", + Enabled: true, + Action: "drop", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, "dropPolicy", policy.Name) + }, + }, + { + name: "Create policy with TCP rule and ports", + requestBody: &api.PolicyCreate{ + Name: "tcpPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "tcpRule", + Enabled: true, + Action: "accept", + Protocol: "tcp", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + Ports: &[]string{"80", "443"}, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, "tcpPolicy", policy.Name) + assert.NotNil(t, policy.Rules[0].Ports) + assert.Equal(t, 2, len(*policy.Rules[0].Ports)) + }, + }, + { + name: "Create policy with empty name", + requestBody: &api.PolicyCreate{ + Name: "", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "rule", + Enabled: true, + Action: "accept", + Protocol: "all", + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: 
http.StatusUnprocessableEntity, + }, + { + name: "Create policy with no rules", + requestBody: &api.PolicyCreate{ + Name: "noRulesPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/policies", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Policy{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify policy exists in DB with correct fields + db := testing_tools.GetDB(t, am.GetStore()) + dbPolicy := testing_tools.VerifyPolicyInDB(t, db, *got.Id) + assert.Equal(t, tc.requestBody.Name, dbPolicy.Name) + assert.Equal(t, tc.requestBody.Enabled, dbPolicy.Enabled) + assert.Equal(t, len(tc.requestBody.Rules), len(dbPolicy.Rules)) + } + }) + } + } +} + +func Test_Policies_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, 
+ {"Invalid token", testing_tools.InvalidToken, false}, + } + + srcGroups := []string{testing_tools.TestGroupId} + dstGroups := []string{testing_tools.TestGroupId} + + tt := []struct { + name string + policyId string + requestBody *api.PolicyCreate + expectedStatus int + verifyResponse func(t *testing.T, policy *api.Policy) + }{ + { + name: "Update policy name", + policyId: "testPolicyId", + requestBody: &api.PolicyCreate{ + Name: "updatedPolicy", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "testRule", + Enabled: true, + Action: "accept", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, "updatedPolicy", policy.Name) + }, + }, + { + name: "Update policy enabled state", + policyId: "testPolicyId", + requestBody: &api.PolicyCreate{ + Name: "testPolicy", + Enabled: false, + Rules: []api.PolicyRuleUpdate{ + { + Name: "testRule", + Enabled: true, + Action: "accept", + Protocol: "all", + Bidirectional: true, + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, policy *api.Policy) { + t.Helper() + assert.Equal(t, false, policy.Enabled) + }, + }, + { + name: "Update non-existing policy", + policyId: "nonExistingPolicyId", + requestBody: &api.PolicyCreate{ + Name: "whatever", + Enabled: true, + Rules: []api.PolicyRuleUpdate{ + { + Name: "rule", + Enabled: true, + Action: "accept", + Protocol: "all", + Sources: &srcGroups, + Destinations: &dstGroups, + }, + }, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + 
t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Policy{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated policy in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbPolicy := testing_tools.VerifyPolicyInDB(t, db, tc.policyId) + assert.Equal(t, tc.requestBody.Name, dbPolicy.Name) + assert.Equal(t, tc.requestBody.Enabled, dbPolicy.Enabled) + } + }) + } + } +} + +func Test_Policies_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + policyId string + expectedStatus int + }{ + { + name: "Delete existing policy", + policyId: "testPolicyId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing policy", + policyId: "nonExistingPolicyId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, 
"../testdata/policies.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPolicyNotInDB(t, db, tc.policyId) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/routes_handler_integration_test.go b/management/server/http/testing/integration/routes_handler_integration_test.go new file mode 100644 index 000000000..eeb0c3025 --- /dev/null +++ b/management/server/http/testing/integration/routes_handler_integration_test.go @@ -0,0 +1,455 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Routes_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all routes", 
func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/routes", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.Route{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 2, len(got)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Routes_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + routeId string + expectedStatus int + expectRoute bool + }{ + { + name: "Get existing route", + routeId: "testRouteId", + expectedStatus: http.StatusOK, + expectRoute: true, + }, + { + name: "Get non-existing route", + routeId: "nonExistingRouteId", + expectedStatus: http.StatusNotFound, + expectRoute: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, 
strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectRoute { + got := &api.Route{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.routeId, got.Id) + assert.Equal(t, "Test Network Route", got.Description) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Routes_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + networkCIDR := "10.10.0.0/24" + peerID := testing_tools.TestPeerId + peerGroups := []string{"peerGroupId"} + + tt := []struct { + name string + requestBody *api.RouteRequest + expectedStatus int + verifyResponse func(t *testing.T, route *api.Route) + }{ + { + name: "Create network route with peer", + requestBody: &api.RouteRequest{ + Description: "New network route", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "newNet", + Metric: 100, + Masquerade: true, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.NotEmpty(t, route.Id) + assert.Equal(t, "New 
network route", route.Description) + assert.Equal(t, 100, route.Metric) + assert.Equal(t, true, route.Masquerade) + assert.Equal(t, true, route.Enabled) + }, + }, + { + name: "Create network route with peer groups", + requestBody: &api.RouteRequest{ + Description: "Route with peer groups", + Network: &networkCIDR, + PeerGroups: &peerGroups, + NetworkId: "peerGroupNet", + Metric: 150, + Masquerade: false, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.NotEmpty(t, route.Id) + assert.Equal(t, "Route with peer groups", route.Description) + }, + }, + { + name: "Create route with empty network_id", + requestBody: &api.RouteRequest{ + Description: "Empty net id", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "", + Metric: 100, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create route with metric 0", + requestBody: &api.RouteRequest{ + Description: "Zero metric", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "zeroMetric", + Metric: 0, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create route with metric 10000", + requestBody: &api.RouteRequest{ + Description: "High metric", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "highMetric", + Metric: 10000, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := 
testing_tools.BuildRequest(t, body, http.MethodPost, "/api/routes", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Route{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify route exists in DB with correct fields + db := testing_tools.GetDB(t, am.GetStore()) + dbRoute := testing_tools.VerifyRouteInDB(t, db, route.ID(got.Id)) + assert.Equal(t, tc.requestBody.Description, dbRoute.Description) + assert.Equal(t, tc.requestBody.Metric, dbRoute.Metric) + assert.Equal(t, tc.requestBody.Masquerade, dbRoute.Masquerade) + assert.Equal(t, tc.requestBody.Enabled, dbRoute.Enabled) + assert.Equal(t, route.NetID(tc.requestBody.NetworkId), dbRoute.NetID) + } + }) + } + } +} + +func Test_Routes_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + networkCIDR := "10.0.0.0/24" + peerID := testing_tools.TestPeerId + + tt := []struct { + name string + routeId string + requestBody *api.RouteRequest + expectedStatus int + verifyResponse func(t *testing.T, route *api.Route) + }{ + { + name: "Update route description", + routeId: "testRouteId", + requestBody: &api.RouteRequest{ + Description: "Updated description", + Network: &networkCIDR, + 
Peer: &peerID, + NetworkId: "testNet", + Metric: 100, + Masquerade: true, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.Equal(t, "testRouteId", route.Id) + assert.Equal(t, "Updated description", route.Description) + }, + }, + { + name: "Update route metric", + routeId: "testRouteId", + requestBody: &api.RouteRequest{ + Description: "Test Network Route", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "testNet", + Metric: 500, + Masquerade: true, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, route *api.Route) { + t.Helper() + assert.Equal(t, 500, route.Metric) + }, + }, + { + name: "Update non-existing route", + routeId: "nonExistingRouteId", + requestBody: &api.RouteRequest{ + Description: "whatever", + Network: &networkCIDR, + Peer: &peerID, + NetworkId: "testNet", + Metric: 100, + Enabled: true, + Groups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.Route{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in 
correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated route in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbRoute := testing_tools.VerifyRouteInDB(t, db, route.ID(got.Id)) + assert.Equal(t, tc.requestBody.Description, dbRoute.Description) + assert.Equal(t, tc.requestBody.Metric, dbRoute.Metric) + assert.Equal(t, tc.requestBody.Masquerade, dbRoute.Masquerade) + assert.Equal(t, tc.requestBody.Enabled, dbRoute.Enabled) + } + }) + } + } +} + +func Test_Routes_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + routeId string + expectedStatus int + }{ + { + name: "Delete existing route", + routeId: "testRouteId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing route", + routeId: "nonExistingRouteId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify route was deleted from DB for successful deletes + if tc.expectedStatus == http.StatusOK && 
user.expectResponse { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyRouteNotInDB(t, db, route.ID(tc.routeId)) + } + }) + } + } +} diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index c1a9829da..0d3aaac82 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -3,7 +3,6 @@ package integration import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -14,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/http/api" @@ -254,7 +252,7 @@ func Test_SetupKeys_Create(t *testing.T) { expectedResponse: nil, }, { - name: "Create Setup Key", + name: "Create Setup Key with nil AutoGroups", requestType: http.MethodPost, requestPath: "/api/setup-keys", requestBody: &api.CreateSetupKeyRequest{ @@ -308,14 +306,15 @@ func Test_SetupKeys_Create(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + gotID := got.Id validateCreatedKey(t, tc.expectedResponse, got) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - if err != nil { - return - } - - validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + // Verify setup key exists in DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, tc.expectedResponse.Name, dbKey.Name) + assert.Equal(t, tc.expectedResponse.Revoked, dbKey.Revoked) + assert.Equal(t, 
tc.expectedResponse.UsageLimit, dbKey.UsageLimit) select { case <-done: @@ -571,7 +570,7 @@ func Test_SetupKeys_Update(t *testing.T) { for _, tc := range tt { for _, user := range users { - t.Run(tc.name, func(t *testing.T) { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) @@ -594,14 +593,16 @@ func Test_SetupKeys_Update(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + gotID := got.Id + gotRevoked := got.Revoked + gotUsageLimit := got.UsageLimit validateCreatedKey(t, tc.expectedResponse, got) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - if err != nil { - return - } - - validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + // Verify updated setup key in DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, gotRevoked, dbKey.Revoked) + assert.Equal(t, gotUsageLimit, dbKey.UsageLimit) select { case <-done: @@ -759,8 +760,8 @@ func Test_SetupKeys_Get(t *testing.T) { apiHandler.ServeHTTP(recorder, req) - content, expectRespnose := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) - if !expectRespnose { + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { return } got := &api.SetupKey{} @@ -768,14 +769,16 @@ func Test_SetupKeys_Get(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + gotID := got.Id + gotName := got.Name + gotRevoked := got.Revoked validateCreatedKey(t, tc.expectedResponse, got) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - if err != nil { - return - } - - validateCreatedKey(t, 
tc.expectedResponse, setup_keys.ToResponseBody(key)) + // Verify setup key in DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, gotName, dbKey.Name) + assert.Equal(t, gotRevoked, dbKey.Revoked) select { case <-done: @@ -928,15 +931,17 @@ func Test_SetupKeys_GetAll(t *testing.T) { return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit }) + db := testing_tools.GetDB(t, am.GetStore()) for i := range tc.expectedResponse { + gotID := got[i].Id + gotName := got[i].Name + gotRevoked := got[i].Revoked validateCreatedKey(t, tc.expectedResponse[i], &got[i]) - key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id) - if err != nil { - return - } - - validateCreatedKey(t, tc.expectedResponse[i], setup_keys.ToResponseBody(key)) + // Verify each setup key in DB via gorm + dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID) + assert.Equal(t, gotName, dbKey.Name) + assert.Equal(t, gotRevoked, dbKey.Revoked) } select { @@ -1104,8 +1109,9 @@ func Test_SetupKeys_Delete(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } - _, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) - assert.Errorf(t, err, "Expected error when trying to get deleted key") + // Verify setup key deleted from DB via gorm + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifySetupKeyNotInDB(t, db, got.Id) select { case <-done: @@ -1120,7 +1126,7 @@ func Test_SetupKeys_Delete(t *testing.T) { func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) { t.Helper() - if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second)) || + if (got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second))) || 
got.Expires.After(time.Date(2300, 01, 01, 0, 0, 0, 0, time.Local)) || got.Expires.Before(time.Date(1950, 01, 01, 0, 0, 0, 0, time.Local)) { got.Expires = time.Time{} diff --git a/management/server/http/testing/integration/users_handler_integration_test.go b/management/server/http/testing/integration/users_handler_integration_test.go new file mode 100644 index 000000000..eae3b4ad5 --- /dev/null +++ b/management/server/http/testing/integration/users_handler_integration_test.go @@ -0,0 +1,701 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func Test_Users_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, true}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, true}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all users", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + 
return + } + + got := []api.User{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.GreaterOrEqual(t, len(got), 1) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Users_GetAll_ServiceUsers(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all service users", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users?service_user=true", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.User{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + for _, u := range got { + assert.NotNil(t, u.IsServiceUser) + assert.Equal(t, true, *u.IsServiceUser) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_Users_Create_ServiceUser(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool 
+ }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + requestBody *api.UserCreateRequest + expectedStatus int + verifyResponse func(t *testing.T, user *api.User) + }{ + { + name: "Create service user with admin role", + requestBody: &api.UserCreateRequest{ + Role: "admin", + IsServiceUser: true, + AutoGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.NotEmpty(t, user.Id) + assert.Equal(t, "admin", user.Role) + assert.NotNil(t, user.IsServiceUser) + assert.Equal(t, true, *user.IsServiceUser) + }, + }, + { + name: "Create service user with user role", + requestBody: &api.UserCreateRequest{ + Role: "user", + IsServiceUser: true, + AutoGroups: []string{testing_tools.TestGroupId}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.NotEmpty(t, user.Id) + assert.Equal(t, "user", user.Role) + }, + }, + { + name: "Create service user with empty auto_groups", + requestBody: &api.UserCreateRequest{ + Role: "admin", + IsServiceUser: true, + AutoGroups: []string{}, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.NotEmpty(t, user.Id) + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + body, err := 
json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users", user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.User{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify user in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbUser := testing_tools.VerifyUserInDB(t, db, got.Id) + assert.True(t, dbUser.IsServiceUser) + assert.Equal(t, string(dbUser.Role), string(tc.requestBody.Role)) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_Users_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + requestBody *api.UserRequest + expectedStatus int + verifyResponse func(t *testing.T, user *api.User) + }{ + { + name: "Update user role to admin", + targetUserId: testing_tools.TestUserId, + requestBody: &api.UserRequest{ + Role: "admin", + AutoGroups: []string{}, + IsBlocked: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t 
*testing.T, user *api.User) { + t.Helper() + assert.Equal(t, "admin", user.Role) + }, + }, + { + name: "Update user auto_groups", + targetUserId: testing_tools.TestUserId, + requestBody: &api.UserRequest{ + Role: "user", + AutoGroups: []string{testing_tools.TestGroupId}, + IsBlocked: false, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.Equal(t, 1, len(user.AutoGroups)) + }, + }, + { + name: "Block user", + targetUserId: testing_tools.TestUserId, + requestBody: &api.UserRequest{ + Role: "user", + AutoGroups: []string{}, + IsBlocked: true, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, user *api.User) { + t.Helper() + assert.Equal(t, true, user.IsBlocked) + }, + }, + { + name: "Update non-existing user", + targetUserId: "nonExistingUserId", + requestBody: &api.UserRequest{ + Role: "user", + AutoGroups: []string{}, + IsBlocked: false, + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/users/{userId}", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.User{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify updated fields in DB + if tc.expectedStatus == http.StatusOK { + db 
:= testing_tools.GetDB(t, am.GetStore()) + dbUser := testing_tools.VerifyUserInDB(t, db, tc.targetUserId) + assert.Equal(t, string(dbUser.Role), string(tc.requestBody.Role)) + assert.Equal(t, dbUser.Blocked, tc.requestBody.IsBlocked) + assert.ElementsMatch(t, dbUser.AutoGroups, tc.requestBody.AutoGroups) + } + } + }) + } + } +} + +func Test_Users_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + expectedStatus int + }{ + { + name: "Delete existing service user", + targetUserId: "deletableServiceUserId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing user", + targetUserId: "nonExistingUserId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify user deleted from DB for successful deletes + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + 
testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PATs_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + for _, user := range users { + t.Run(user.name+" - Get all PATs for service user", func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/users/{userId}/tokens", "{userId}", testing_tools.TestServiceUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse) + if !expectResponse { + return + } + + got := []api.PersonalAccessToken{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, 1, len(got)) + assert.Equal(t, "serviceToken", got[0].Name) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } +} + +func Test_PATs_GetById(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, 
+ {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + tokenId string + expectedStatus int + expectToken bool + }{ + { + name: "Get existing PAT", + tokenId: "serviceTokenId", + expectedStatus: http.StatusOK, + expectToken: true, + }, + { + name: "Get non-existing PAT", + tokenId: "nonExistingTokenId", + expectedStatus: http.StatusNotFound, + expectToken: false, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + path := strings.Replace("/api/users/{userId}/tokens/{tokenId}", "{userId}", testing_tools.TestServiceUserId, 1) + path = strings.Replace(path, "{tokenId}", tc.tokenId, 1) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.expectToken { + got := &api.PersonalAccessToken{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, "serviceTokenId", got.Id) + assert.Equal(t, "serviceToken", got.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PATs_Create(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", 
testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + targetUserId string + requestBody *api.PersonalAccessTokenRequest + expectedStatus int + verifyResponse func(t *testing.T, pat *api.PersonalAccessTokenGenerated) + }{ + { + name: "Create PAT with 30 day expiry", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "newPAT", + ExpiresIn: 30, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, pat *api.PersonalAccessTokenGenerated) { + t.Helper() + assert.NotEmpty(t, pat.PlainToken) + assert.Equal(t, "newPAT", pat.PersonalAccessToken.Name) + }, + }, + { + name: "Create PAT with 365 day expiry", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "longPAT", + ExpiresIn: 365, + }, + expectedStatus: http.StatusOK, + verifyResponse: func(t *testing.T, pat *api.PersonalAccessTokenGenerated) { + t.Helper() + assert.NotEmpty(t, pat.PlainToken) + assert.Equal(t, "longPAT", pat.PersonalAccessToken.Name) + }, + }, + { + name: "Create PAT with empty name", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "", + ExpiresIn: 30, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create PAT with 0 day expiry", + targetUserId: testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "zeroPAT", + ExpiresIn: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + { + name: "Create PAT with expiry over 365 days", + targetUserId: 
testing_tools.TestServiceUserId, + requestBody: &api.PersonalAccessTokenRequest{ + Name: "tooLongPAT", + ExpiresIn: 400, + }, + expectedStatus: http.StatusUnprocessableEntity, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, http.MethodPost, strings.Replace("/api/users/{userId}/tokens", "{userId}", tc.targetUserId, 1), user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + + if tc.verifyResponse != nil { + got := &api.PersonalAccessTokenGenerated{} + if err := json.Unmarshal(content, got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + tc.verifyResponse(t, got) + + // Verify PAT in DB + db := testing_tools.GetDB(t, am.GetStore()) + dbPAT := testing_tools.VerifyPATInDB(t, db, got.PersonalAccessToken.Id) + assert.Equal(t, tc.requestBody.Name, dbPAT.Name) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_PATs_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + {"Regular user", testing_tools.TestUserId, false}, + {"Admin user", testing_tools.TestAdminId, true}, + {"Owner user", testing_tools.TestOwnerId, true}, + {"Regular service user", testing_tools.TestServiceUserId, false}, + {"Admin service user", testing_tools.TestServiceAdminId, true}, + {"Blocked user", testing_tools.BlockedUserId, false}, + {"Other user", testing_tools.OtherUserId, 
false}, + {"Invalid token", testing_tools.InvalidToken, false}, + } + + tt := []struct { + name string + tokenId string + expectedStatus int + }{ + { + name: "Delete existing PAT", + tokenId: "serviceTokenId", + expectedStatus: http.StatusOK, + }, + { + name: "Delete non-existing PAT", + tokenId: "nonExistingTokenId", + expectedStatus: http.StatusNotFound, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true) + + path := strings.Replace("/api/users/{userId}/tokens/{tokenId}", "{userId}", testing_tools.TestServiceUserId, 1) + path = strings.Replace(path, "{tokenId}", tc.tokenId, 1) + + req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId) + recorder := httptest.NewRecorder() + apiHandler.ServeHTTP(recorder, req) + + _, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + + // Verify PAT deleted from DB for successful deletes + if expectResponse && tc.expectedStatus == http.StatusOK { + db := testing_tools.GetDB(t, am.GetStore()) + testing_tools.VerifyPATNotInDB(t, db, tc.tokenId) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} diff --git a/management/server/http/testing/testdata/accounts.sql b/management/server/http/testing/testdata/accounts.sql new file mode 100644 index 000000000..35f00d419 --- /dev/null +++ b/management/server/http/testing/testdata/accounts.sql @@ -0,0 +1,18 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` 
numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY 
(`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO peers 
VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); diff --git a/management/server/http/testing/testdata/dns.sql b/management/server/http/testing/testdata/dns.sql new file mode 100644 index 000000000..9ed4daf7e --- /dev/null +++ b/management/server/http/testing/testdata/dns.sql @@ -0,0 +1,21 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` 
integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` 
numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers 
VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/events.sql b/management/server/http/testing/testdata/events.sql new file mode 100644 index 000000000..27fd01aea --- /dev/null +++ b/management/server/http/testing/testdata/events.sql @@ -0,0 +1,18 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` 
integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 
16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 
AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/groups.sql b/management/server/http/testing/testdata/groups.sql new file mode 100644 index 000000000..eb874f036 --- /dev/null +++ b/management/server/http/testing/testdata/groups.sql @@ -0,0 +1,19 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` 
datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users 
VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('allGroupId','testAccountId','All','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/networks.sql b/management/server/http/testing/testdata/networks.sql new file mode 100644 index 000000000..39ec8e646 --- /dev/null +++ b/management/server/http/testing/testdata/networks.sql @@ -0,0 +1,25 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` 
text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` 
text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,`enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`domain` text,`prefix` text,`enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 
16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'testServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testServiceAdmin','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO networks VALUES('testNetworkId','testAccountId','testNetwork','test network description'); +INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','testPeerId','[]',1,100,1); +INSERT INTO network_resources 
VALUES('testResourceId','testNetworkId','testAccountId','testResource','test resource description','host','','"3.3.3.3/32"',1); \ No newline at end of file diff --git a/management/server/http/testing/testdata/peers_integration.sql b/management/server/http/testing/testdata/peers_integration.sql new file mode 100644 index 000000000..62a7760e7 --- /dev/null +++ b/management/server/http/testing/testdata/peers_integration.sql @@ -0,0 +1,20 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` 
datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users 
VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId","testPeerId2"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0); + +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','test-host-1','linux','Linux','','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-1','test-peer-1','2023-03-02 09:21:02.189035775+01:00',0,0,0,'testUserId','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('testPeerId2','testAccountId','6rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYBg=','82546A29-6BC8-4311-BCFC-9CDBF33F1A49','"100.64.114.32"','test-host-2','linux','Linux','','unknown','Ubuntu','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-2','test-peer-2','2023-03-02 09:21:02.189035775+01:00',1,0,0,'testAdminId','ssh-ed25519 
AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); \ No newline at end of file diff --git a/management/server/http/testing/testdata/policies.sql b/management/server/http/testing/testdata/policies.sql new file mode 100644 index 000000000..7e6cc883b --- /dev/null +++ b/management/server/http/testing/testdata/policies.sql @@ -0,0 +1,23 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` 
datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`protocol` text,`bidirectional` numeric,`sources` text,`destinations` text,`source_resource` text,`destination_resource` text,`ports` text,`port_ranges` text,`authorized_groups` text,`authorized_user` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules_g` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`)); + +INSERT INTO accounts 
VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 
AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO policies VALUES('testPolicyId','testAccountId','testPolicy','test policy description',1,NULL); +INSERT INTO policy_rules VALUES('testRuleId','testPolicyId','testRule','test rule',1,'accept','all',1,'["testGroupId"]','["testGroupId"]',NULL,NULL,NULL,NULL,NULL,''); \ No newline at end of file diff --git a/management/server/http/testing/testdata/routes.sql b/management/server/http/testing/testdata/routes.sql new file mode 100644 index 000000000..48aa02052 --- /dev/null +++ b/management/server/http/testing/testdata/routes.sql @@ -0,0 +1,23 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` 
integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,`skip_auto_apply` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES 
`accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('peerGroupId','testAccountId','peerGroupName','api','["testPeerId"]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian 
GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO routes VALUES('testRouteId','testAccountId','"10.0.0.0/24"',NULL,0,'testNet','Test Network Route','testPeerId',NULL,1,1,100,1,'["testGroupId"]',NULL,0); +INSERT INTO routes VALUES('testDomainRouteId','testAccountId','"0.0.0.0/0"','["example.com"]',0,'testDomainNet','Test Domain Route','','["peerGroupId"]',3,1,200,1,'["testGroupId"]',NULL,0); diff --git a/management/server/http/testing/testdata/users_integration.sql b/management/server/http/testing/testdata/users_integration.sql new file mode 100644 index 000000000..57df73e8c --- /dev/null +++ b/management/server/http/testing/testdata/users_integration.sql @@ -0,0 +1,24 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` 
text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` 
datetime DEFAULT NULL,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'testServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testServiceAdmin','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('deletableServiceUserId','testAccountId','user',1,0,'deletableServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 
20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +INSERT INTO personal_access_tokens VALUES('testTokenId','testUserId','testToken','hashedTokenValue123','2325-10-02 16:01:38.000000000+00:00','testUserId','2024-10-02 16:01:38.000000000+00:00',NULL); +INSERT INTO personal_access_tokens VALUES('serviceTokenId','testServiceUserId','serviceToken','hashedServiceTokenValue123','2325-10-02 16:01:38.000000000+00:00','testAdminId','2024-10-02 16:01:38.000000000+00:00',NULL); \ No newline at end of file diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 6bd269a2c..c6e57b1be 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -114,8 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } - domainManager.SetClusterCapabilities(serviceProxyController) - serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) @@ -128,14 +127,14 @@ func 
BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee GetPATInfoFunc: authManager.GetPATInfo, } - networksManagerMock := networks.NewManagerMock() - resourcesManagerMock := resources.NewManagerMock() - routersManagerMock := routers.NewManagerMock() - groupsManagerMock := groups.NewManagerMock() + groupsManager := groups.NewManager(store, permissionsManager, am) + routersManager := routers.NewManager(store, permissionsManager, am) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am) customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -167,6 +166,111 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_m } } +// PeerShouldReceiveAnyUpdate waits for a peer update message and returns it. +// Fails the test if no update is received within timeout. 
+func PeerShouldReceiveAnyUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) *network_map.UpdateMessage { + t.Helper() + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + return msg + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + return nil + } +} + +// PeerShouldNotReceiveAnyUpdate verifies no peer update message is received. +func PeerShouldNotReceiveAnyUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) { + t.Helper() + peerShouldNotReceiveUpdate(t, updateMessage) +} + +// BuildApiBlackBoxWithDBStateAndPeerChannel creates the API handler and returns +// the peer update channel directly so tests can verify updates inline. +func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile string) (http.Handler, account.Manager, <-chan *network_map.UpdateMessage) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := update_channel.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) + + geoMock := &geolocation.Mock{} + validatorMock := server.MockIntegratedValidator{} + proxyController := integrations.NewController(store) + userManager := users.NewManager(store) + permissionsManager := permissions.NewManager(store) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{}) + peersManager := peers.NewManager(store, permissionsManager) + + jobManager := job.NewJobManager(nil, store, peersManager) + + ctx := 
context.Background() + requestBuffer := server.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) + proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + if err != nil { + t.Fatalf("Failed to create proxy token store: %v", err) + } + pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + if err != nil { + t.Fatalf("Failed to create PKCE verifier store: %v", err) + } + noopMeter := noop.NewMeterProvider().Meter("") + proxyMgr, err := proxymanager.NewManager(store, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy manager: %v", err) + } + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) + serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy controller: %v", err) + } + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) + proxyServiceServer.SetServiceManager(serviceManager) + am.SetServiceManager(serviceManager) + + // @note this is required so that PAT's validate from store, 
but JWT's are mocked + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &serverauth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, + MarkPATUsedFunc: authManager.MarkPATUsed, + GetPATInfoFunc: authManager.GetPATInfo, + } + + groupsManager := groups.NewManager(store, permissionsManager, am) + routersManager := routers.NewManager(store, permissionsManager, am) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am) + customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") + zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) + + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, updMsg +} + func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) { userAuth := auth.UserAuth{} diff --git a/management/server/http/testing/testing_tools/db_verify.go b/management/server/http/testing/testing_tools/db_verify.go new file mode 100644 index 000000000..f8af6a41f --- /dev/null +++ b/management/server/http/testing/testing_tools/db_verify.go @@ -0,0 +1,222 @@ +package testing_tools + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes 
"github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// GetDB extracts the *gorm.DB from a store.Store (must be *SqlStore). +func GetDB(t *testing.T, s store.Store) *gorm.DB { + t.Helper() + sqlStore, ok := s.(*store.SqlStore) + require.True(t, ok, "Store is not a *SqlStore, cannot get gorm.DB") + return sqlStore.GetDB() +} + +// VerifyGroupInDB reads a group directly from the DB and returns it. +func VerifyGroupInDB(t *testing.T, db *gorm.DB, groupID string) *types.Group { + t.Helper() + var group types.Group + err := db.Where("id = ? AND account_id = ?", groupID, TestAccountId).First(&group).Error + require.NoError(t, err, "Expected group %s to exist in DB", groupID) + return &group +} + +// VerifyGroupNotInDB verifies that a group does not exist in the DB. +func VerifyGroupNotInDB(t *testing.T, db *gorm.DB, groupID string) { + t.Helper() + var count int64 + db.Model(&types.Group{}).Where("id = ? AND account_id = ?", groupID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected group %s to NOT exist in DB", groupID) +} + +// VerifyPolicyInDB reads a policy directly from the DB and returns it. +func VerifyPolicyInDB(t *testing.T, db *gorm.DB, policyID string) *types.Policy { + t.Helper() + var policy types.Policy + err := db.Preload("Rules").Where("id = ? AND account_id = ?", policyID, TestAccountId).First(&policy).Error + require.NoError(t, err, "Expected policy %s to exist in DB", policyID) + return &policy +} + +// VerifyPolicyNotInDB verifies that a policy does not exist in the DB. 
+func VerifyPolicyNotInDB(t *testing.T, db *gorm.DB, policyID string) { + t.Helper() + var count int64 + db.Model(&types.Policy{}).Where("id = ? AND account_id = ?", policyID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected policy %s to NOT exist in DB", policyID) +} + +// VerifyRouteInDB reads a route directly from the DB and returns it. +func VerifyRouteInDB(t *testing.T, db *gorm.DB, routeID route.ID) *route.Route { + t.Helper() + var r route.Route + err := db.Where("id = ? AND account_id = ?", routeID, TestAccountId).First(&r).Error + require.NoError(t, err, "Expected route %s to exist in DB", routeID) + return &r +} + +// VerifyRouteNotInDB verifies that a route does not exist in the DB. +func VerifyRouteNotInDB(t *testing.T, db *gorm.DB, routeID route.ID) { + t.Helper() + var count int64 + db.Model(&route.Route{}).Where("id = ? AND account_id = ?", routeID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected route %s to NOT exist in DB", routeID) +} + +// VerifyNSGroupInDB reads a nameserver group directly from the DB and returns it. +func VerifyNSGroupInDB(t *testing.T, db *gorm.DB, nsGroupID string) *nbdns.NameServerGroup { + t.Helper() + var nsGroup nbdns.NameServerGroup + err := db.Where("id = ? AND account_id = ?", nsGroupID, TestAccountId).First(&nsGroup).Error + require.NoError(t, err, "Expected NS group %s to exist in DB", nsGroupID) + return &nsGroup +} + +// VerifyNSGroupNotInDB verifies that a nameserver group does not exist in the DB. +func VerifyNSGroupNotInDB(t *testing.T, db *gorm.DB, nsGroupID string) { + t.Helper() + var count int64 + db.Model(&nbdns.NameServerGroup{}).Where("id = ? AND account_id = ?", nsGroupID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected NS group %s to NOT exist in DB", nsGroupID) +} + +// VerifyPeerInDB reads a peer directly from the DB and returns it. 
+func VerifyPeerInDB(t *testing.T, db *gorm.DB, peerID string) *nbpeer.Peer { + t.Helper() + var peer nbpeer.Peer + err := db.Where("id = ? AND account_id = ?", peerID, TestAccountId).First(&peer).Error + require.NoError(t, err, "Expected peer %s to exist in DB", peerID) + return &peer +} + +// VerifyPeerNotInDB verifies that a peer does not exist in the DB. +func VerifyPeerNotInDB(t *testing.T, db *gorm.DB, peerID string) { + t.Helper() + var count int64 + db.Model(&nbpeer.Peer{}).Where("id = ? AND account_id = ?", peerID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected peer %s to NOT exist in DB", peerID) +} + +// VerifySetupKeyInDB reads a setup key directly from the DB and returns it. +func VerifySetupKeyInDB(t *testing.T, db *gorm.DB, keyID string) *types.SetupKey { + t.Helper() + var key types.SetupKey + err := db.Where("id = ? AND account_id = ?", keyID, TestAccountId).First(&key).Error + require.NoError(t, err, "Expected setup key %s to exist in DB", keyID) + return &key +} + +// VerifySetupKeyNotInDB verifies that a setup key does not exist in the DB. +func VerifySetupKeyNotInDB(t *testing.T, db *gorm.DB, keyID string) { + t.Helper() + var count int64 + db.Model(&types.SetupKey{}).Where("id = ? AND account_id = ?", keyID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected setup key %s to NOT exist in DB", keyID) +} + +// VerifyUserInDB reads a user directly from the DB and returns it. +func VerifyUserInDB(t *testing.T, db *gorm.DB, userID string) *types.User { + t.Helper() + var user types.User + err := db.Where("id = ? AND account_id = ?", userID, TestAccountId).First(&user).Error + require.NoError(t, err, "Expected user %s to exist in DB", userID) + return &user +} + +// VerifyUserNotInDB verifies that a user does not exist in the DB. +func VerifyUserNotInDB(t *testing.T, db *gorm.DB, userID string) { + t.Helper() + var count int64 + db.Model(&types.User{}).Where("id = ? 
AND account_id = ?", userID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected user %s to NOT exist in DB", userID) +} + +// VerifyPATInDB reads a PAT directly from the DB and returns it. +func VerifyPATInDB(t *testing.T, db *gorm.DB, tokenID string) *types.PersonalAccessToken { + t.Helper() + var pat types.PersonalAccessToken + err := db.Where("id = ?", tokenID).First(&pat).Error + require.NoError(t, err, "Expected PAT %s to exist in DB", tokenID) + return &pat +} + +// VerifyPATNotInDB verifies that a PAT does not exist in the DB. +func VerifyPATNotInDB(t *testing.T, db *gorm.DB, tokenID string) { + t.Helper() + var count int64 + db.Model(&types.PersonalAccessToken{}).Where("id = ?", tokenID).Count(&count) + assert.Equal(t, int64(0), count, "Expected PAT %s to NOT exist in DB", tokenID) +} + +// VerifyAccountSettings reads the account and returns its settings from the DB. +func VerifyAccountSettings(t *testing.T, db *gorm.DB) *types.Account { + t.Helper() + var account types.Account + err := db.Where("id = ?", TestAccountId).First(&account).Error + require.NoError(t, err, "Expected account %s to exist in DB", TestAccountId) + return &account +} + +// VerifyNetworkInDB reads a network directly from the store and returns it. +func VerifyNetworkInDB(t *testing.T, db *gorm.DB, networkID string) *networkTypes.Network { + t.Helper() + var network networkTypes.Network + err := db.Where("id = ? AND account_id = ?", networkID, TestAccountId).First(&network).Error + require.NoError(t, err, "Expected network %s to exist in DB", networkID) + return &network +} + +// VerifyNetworkNotInDB verifies that a network does not exist in the DB. +func VerifyNetworkNotInDB(t *testing.T, db *gorm.DB, networkID string) { + t.Helper() + var count int64 + db.Model(&networkTypes.Network{}).Where("id = ? 
AND account_id = ?", networkID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected network %s to NOT exist in DB", networkID) +} + +// VerifyNetworkResourceInDB reads a network resource directly from the DB and returns it. +func VerifyNetworkResourceInDB(t *testing.T, db *gorm.DB, resourceID string) *resourceTypes.NetworkResource { + t.Helper() + var resource resourceTypes.NetworkResource + err := db.Where("id = ? AND account_id = ?", resourceID, TestAccountId).First(&resource).Error + require.NoError(t, err, "Expected network resource %s to exist in DB", resourceID) + return &resource +} + +// VerifyNetworkResourceNotInDB verifies that a network resource does not exist in the DB. +func VerifyNetworkResourceNotInDB(t *testing.T, db *gorm.DB, resourceID string) { + t.Helper() + var count int64 + db.Model(&resourceTypes.NetworkResource{}).Where("id = ? AND account_id = ?", resourceID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected network resource %s to NOT exist in DB", resourceID) +} + +// VerifyNetworkRouterInDB reads a network router directly from the DB and returns it. +func VerifyNetworkRouterInDB(t *testing.T, db *gorm.DB, routerID string) *routerTypes.NetworkRouter { + t.Helper() + var router routerTypes.NetworkRouter + err := db.Where("id = ? AND account_id = ?", routerID, TestAccountId).First(&router).Error + require.NoError(t, err, "Expected network router %s to exist in DB", routerID) + return &router +} + +// VerifyNetworkRouterNotInDB verifies that a network router does not exist in the DB. +func VerifyNetworkRouterNotInDB(t *testing.T, db *gorm.DB, routerID string) { + t.Helper() + var count int64 + db.Model(&routerTypes.NetworkRouter{}).Where("id = ? 
AND account_id = ?", routerID, TestAccountId).Count(&count) + assert.Equal(t, int64(0), count, "Expected network router %s to NOT exist in DB", routerID) +} diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 28e3d81f9..20d6cacd5 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -197,6 +197,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr case "jumpcloud": return NewJumpCloudManager(JumpCloudClientConfig{ APIToken: config.ExtraConfig["ApiToken"], + ApiUrl: config.ExtraConfig["ApiUrl"], }, appMetrics) case "pocketid": return NewPocketIdManager(PocketIdClientConfig{ diff --git a/management/server/idp/jumpcloud.go b/management/server/idp/jumpcloud.go index 8c4a9d089..f0dec3a9b 100644 --- a/management/server/idp/jumpcloud.go +++ b/management/server/idp/jumpcloud.go @@ -1,24 +1,40 @@ package idp import ( + "bytes" "context" + "encoding/json" "fmt" + "io" "net/http" "strings" - v1 "github.com/TheJumpCloud/jcapi-go/v1" - "github.com/netbirdio/netbird/management/server/telemetry" ) const ( - contentType = "application/json" - accept = "application/json" + jumpCloudDefaultApiUrl = "https://console.jumpcloud.com" + jumpCloudSearchPageSize = 100 ) +// jumpCloudUser represents a JumpCloud V1 API system user. +type jumpCloudUser struct { + ID string `json:"_id"` + Email string `json:"email"` + Firstname string `json:"firstname"` + Middlename string `json:"middlename"` + Lastname string `json:"lastname"` +} + +// jumpCloudUserList represents the response from the JumpCloud search endpoint. +type jumpCloudUserList struct { + Results []jumpCloudUser `json:"results"` + TotalCount int `json:"totalCount"` +} + // JumpCloudManager JumpCloud manager client instance. 
type JumpCloudManager struct { - client *v1.APIClient + apiBase string apiToken string httpClient ManagerHTTPClient credentials ManagerCredentials @@ -29,6 +45,7 @@ type JumpCloudManager struct { // JumpCloudClientConfig JumpCloud manager client configurations. type JumpCloudClientConfig struct { APIToken string + ApiUrl string } // JumpCloudCredentials JumpCloud authentication information. @@ -55,7 +72,15 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing") } - client := v1.NewAPIClient(v1.NewConfiguration()) + apiBase := config.ApiUrl + if apiBase == "" { + apiBase = jumpCloudDefaultApiUrl + } + apiBase = strings.TrimSuffix(apiBase, "/") + if !strings.HasSuffix(apiBase, "/api") { + apiBase += "/api" + } + credentials := &JumpCloudCredentials{ clientConfig: config, httpClient: httpClient, @@ -64,7 +89,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM } return &JumpCloudManager{ - client: client, + apiBase: apiBase, apiToken: config.APIToken, httpClient: httpClient, credentials: credentials, @@ -78,37 +103,58 @@ func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error return JWTToken{}, nil } -func (jm *JumpCloudManager) authenticationContext() context.Context { - return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{ - Key: jm.apiToken, - }) -} - -// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { - return nil -} - -// GetUserDataByID requests user data from JumpCloud via ID. 
-func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { - authCtx := jm.authenticationContext() - user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil) +// doRequest executes an HTTP request against the JumpCloud V1 API. +func (jm *JumpCloudManager) doRequest(ctx context.Context, method, path string, body io.Reader) ([]byte, error) { + reqURL := jm.apiBase + path + req, err := http.NewRequestWithContext(ctx, method, reqURL, body) if err != nil { return nil, err } + + req.Header.Set("x-api-key", jm.apiToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := jm.httpClient.Do(req) + if err != nil { + if jm.appMetrics != nil { + jm.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) + return nil, fmt.Errorf("JumpCloud API request %s %s failed with status %d", method, path, resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. +func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +// GetUserDataByID requests user data from JumpCloud via ID. 
+func (jm *JumpCloudManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + body, err := jm.doRequest(ctx, http.MethodGet, "/systemusers/"+userID, nil) + if err != nil { + return nil, err } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetUserDataByID() } + var user jumpCloudUser + if err = jm.helper.Unmarshal(body, &user); err != nil { + return nil, err + } + userData := parseJumpCloudUser(user) userData.AppMetadata = appMetadata @@ -116,30 +162,20 @@ func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, ap } // GetAccount returns all the users for a given profile. -func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { - authCtx := jm.authenticationContext() - userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) +func (jm *JumpCloudManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + allUsers, err := jm.searchAllUsers(ctx) if err != nil { return nil, err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode) - } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetAccount() } - users := make([]*UserData, 0) - for _, user := range userList.Results { + users := make([]*UserData, 0, len(allUsers)) + for _, user := range allUsers { userData := parseJumpCloudUser(user) userData.AppMetadata.WTAccountID = accountID - users = append(users, userData) } @@ -148,27 +184,18 @@ func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]* // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. 
-func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { - authCtx := jm.authenticationContext() - userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) +func (jm *JumpCloudManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + allUsers, err := jm.searchAllUsers(ctx) if err != nil { return nil, err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) - } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetAllAccounts() } indexedUsers := make(map[string][]*UserData) - for _, user := range userList.Results { + for _, user := range allUsers { userData := parseJumpCloudUser(user) indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) } @@ -176,6 +203,41 @@ func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*Use return indexedUsers, nil } +// searchAllUsers paginates through all system users using limit/skip. +func (jm *JumpCloudManager) searchAllUsers(ctx context.Context) ([]jumpCloudUser, error) { + var allUsers []jumpCloudUser + + for skip := 0; ; skip += jumpCloudSearchPageSize { + searchReq := map[string]int{ + "limit": jumpCloudSearchPageSize, + "skip": skip, + } + + payload, err := json.Marshal(searchReq) + if err != nil { + return nil, err + } + + body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload)) + if err != nil { + return nil, err + } + + var userList jumpCloudUserList + if err = jm.helper.Unmarshal(body, &userList); err != nil { + return nil, err + } + + allUsers = append(allUsers, userList.Results...) 
+ + if skip+len(userList.Results) >= userList.TotalCount { + break + } + } + + return allUsers, nil +} + // CreateUser creates a new user in JumpCloud Idp and sends an invitation. func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") @@ -183,7 +245,7 @@ func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*U // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { +func (jm *JumpCloudManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { searchFilter := map[string]interface{}{ "searchFilter": map[string]interface{}{ "filter": []string{email}, @@ -191,25 +253,26 @@ func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]* }, } - authCtx := jm.authenticationContext() - userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter) + payload, err := json.Marshal(searchFilter) if err != nil { return nil, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode) + body, err := jm.doRequest(ctx, http.MethodPost, "/search/systemusers", bytes.NewReader(payload)) + if err != nil { + return nil, err } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountGetUserByEmail() } - usersData := make([]*UserData, 0) + var userList jumpCloudUserList + if err = jm.helper.Unmarshal(body, &userList); err != nil { + return nil, err + } + + usersData := make([]*UserData, 0, len(userList.Results)) for _, user := range userList.Results { usersData = append(usersData, parseJumpCloudUser(user)) } @@ -224,20 +287,11 @@ 
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error { } // DeleteUser from jumpCloud directory -func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error { - authCtx := jm.authenticationContext() - _, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil) +func (jm *JumpCloudManager) DeleteUser(ctx context.Context, userID string) error { + _, err := jm.doRequest(ctx, http.MethodDelete, "/systemusers/"+userID, nil) if err != nil { return err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if jm.appMetrics != nil { - jm.appMetrics.IDPMetrics().CountRequestStatusError() - } - return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode) - } if jm.appMetrics != nil { jm.appMetrics.IDPMetrics().CountDeleteUser() @@ -247,11 +301,11 @@ func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error { } // parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData. 
-func parseJumpCloudUser(user v1.Systemuserreturn) *UserData { +func parseJumpCloudUser(user jumpCloudUser) *UserData { names := []string{user.Firstname, user.Middlename, user.Lastname} return &UserData{ Email: user.Email, Name: strings.Join(names, " "), - ID: user.Id, + ID: user.ID, } } diff --git a/management/server/idp/jumpcloud_test.go b/management/server/idp/jumpcloud_test.go index 1bfdcefcc..dc7a9cb6c 100644 --- a/management/server/idp/jumpcloud_test.go +++ b/management/server/idp/jumpcloud_test.go @@ -1,8 +1,15 @@ package idp import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/telemetry" @@ -44,3 +51,212 @@ func TestNewJumpCloudManager(t *testing.T) { }) } } + +func TestJumpCloudGetUserDataByID(t *testing.T) { + userResponse := jumpCloudUser{ + ID: "user123", + Email: "test@example.com", + Firstname: "John", + Middlename: "", + Lastname: "Doe", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/systemusers/user123", r.URL.Path) + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "test-api-key", r.Header.Get("x-api-key")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(userResponse) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + userData, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{WTAccountID: "acc1"}) + require.NoError(t, err) + + assert.Equal(t, "user123", userData.ID) + assert.Equal(t, "test@example.com", userData.Email) + assert.Equal(t, "John Doe", userData.Name) + assert.Equal(t, "acc1", userData.AppMetadata.WTAccountID) +} + +func TestJumpCloudGetAccount(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, 
"/search/systemusers", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + + var reqBody map[string]any + assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + assert.Contains(t, reqBody, "limit") + assert.Contains(t, reqBody, "skip") + + resp := jumpCloudUserList{ + Results: []jumpCloudUser{ + {ID: "u1", Email: "a@test.com", Firstname: "Alice", Lastname: "Smith"}, + {ID: "u2", Email: "b@test.com", Firstname: "Bob", Lastname: "Jones"}, + }, + TotalCount: 2, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + users, err := manager.GetAccount(context.Background(), "testAccount") + require.NoError(t, err) + assert.Len(t, users, 2) + assert.Equal(t, "testAccount", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "testAccount", users[1].AppMetadata.WTAccountID) +} + +func TestJumpCloudGetAllAccounts(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := jumpCloudUserList{ + Results: []jumpCloudUser{ + {ID: "u1", Email: "a@test.com", Firstname: "Alice"}, + {ID: "u2", Email: "b@test.com", Firstname: "Bob"}, + }, + TotalCount: 2, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + indexedUsers, err := manager.GetAllAccounts(context.Background()) + require.NoError(t, err) + assert.Len(t, indexedUsers[UnsetAccountID], 2) +} + +func TestJumpCloudGetAllAccountsPagination(t *testing.T) { + totalUsers := 250 + allUsers := make([]jumpCloudUser, totalUsers) + for i := range allUsers { + allUsers[i] = jumpCloudUser{ + ID: fmt.Sprintf("u%d", i), + Email: fmt.Sprintf("user%d@test.com", i), + Firstname: fmt.Sprintf("User%d", i), + } + } + + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + var reqBody map[string]int + assert.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + + limit := reqBody["limit"] + skip := reqBody["skip"] + requestCount++ + + end := skip + limit + if end > totalUsers { + end = totalUsers + } + + resp := jumpCloudUserList{ + Results: allUsers[skip:end], + TotalCount: totalUsers, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + indexedUsers, err := manager.GetAllAccounts(context.Background()) + require.NoError(t, err) + assert.Len(t, indexedUsers[UnsetAccountID], totalUsers) + assert.Equal(t, 3, requestCount, "should require 3 pages for 250 users at page size 100") +} + +func TestJumpCloudGetUserByEmail(t *testing.T) { + searchResponse := jumpCloudUserList{ + Results: []jumpCloudUser{ + {ID: "u1", Email: "alice@test.com", Firstname: "Alice", Lastname: "Smith"}, + }, + TotalCount: 1, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/search/systemusers", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + + body, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Contains(t, string(body), "alice@test.com") + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(searchResponse) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + users, err := manager.GetUserByEmail(context.Background(), "alice@test.com") + require.NoError(t, err) + assert.Len(t, users, 1) + assert.Equal(t, "alice@test.com", users[0].Email) +} + +func TestJumpCloudDeleteUser(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/systemusers/user123", r.URL.Path) + assert.Equal(t, http.MethodDelete, r.Method) + assert.Equal(t, "test-api-key", r.Header.Get("x-api-key")) + + 
w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"_id": "user123"}) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + err := manager.DeleteUser(context.Background(), "user123") + require.NoError(t, err) +} + +func TestJumpCloudAPIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + manager := newTestJumpCloudManager(t, server.URL) + + _, err := manager.GetUserDataByID(context.Background(), "user123", AppMetadata{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "401") +} + +func TestParseJumpCloudUser(t *testing.T) { + user := jumpCloudUser{ + ID: "abc123", + Email: "test@example.com", + Firstname: "John", + Middlename: "M", + Lastname: "Doe", + } + + userData := parseJumpCloudUser(user) + assert.Equal(t, "abc123", userData.ID) + assert.Equal(t, "test@example.com", userData.Email) + assert.Equal(t, "John M Doe", userData.Name) +} + +func newTestJumpCloudManager(t *testing.T, apiBase string) *JumpCloudManager { + t.Helper() + return &JumpCloudManager{ + apiBase: apiBase, + apiToken: "test-api-key", + httpClient: http.DefaultClient, + helper: JsonParser{}, + appMetrics: nil, + } +} diff --git a/management/server/peer.go b/management/server/peer.go index 78ecbfcae..a02e34e0d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if err != nil { newLabel = "" } else { - _, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name) + _, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, newLabel) if err == nil { newLabel = "" } @@ -859,7 +859,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe opEvent.Meta["setup_key_name"] = 
peerAddConfig.SetupKeyName } - am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + if !temporary { + am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + } if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) @@ -1480,9 +1482,11 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } - peerDeletedEvents = append(peerDeletedEvents, func() { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) - }) + if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") { + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + } } return peerDeletedEvents, nil diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b17757ffd..51c16d730 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -37,6 +37,7 @@ import ( "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" @@ -2738,3 +2739,70 @@ func TestProcessPeerAddAuth(t *testing.T) { assert.Empty(t, config.GroupsToAdd) }) } + +func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), 
auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + // Add first peer with hostname that produces DNS label "netbird1" + key1, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key1.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"}, + }, false) + require.NoError(t, err, "unable to add first peer") + assert.Equal(t, "netbird1", peer1.DNSLabel) + + // Add second peer with a different hostname + key2, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key2.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"}, + }, false) + require.NoError(t, err) + + update := peer2.Copy() + update.Name = "netbird1.demo.netbird.cloud" + updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update) + require.NoError(t, err, "renaming peer should not fail with duplicate DNS label error") + assert.Equal(t, "netbird1.demo.netbird.cloud", updated.Name) + assert.NotEqual(t, "netbird1", updated.DNSLabel, "DNS label should not collide with existing peer") + assert.Contains(t, updated.DNSLabel, "netbird1-", "DNS label should be IP-based fallback") +} + +func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + key1, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key1.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"}, + }, false) + require.NoError(t, err) + assert.Equal(t, 
"web-server", peer1.DNSLabel) + + // Add second peer and rename it to a unique FQDN whose first label doesn't collide + key2, err := wgtypes.GenerateKey() + require.NoError(t, err) + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: key2.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"}, + }, false) + require.NoError(t, err) + + update := peer2.Copy() + update.Name = "api-server.example.com" + updated, err := manager.UpdatePeer(context.Background(), accountID, userID, update) + require.NoError(t, err, "renaming to unique FQDN should succeed") + assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN") +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index b3fbfe141..ee1947b18 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4997,7 +4997,6 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpse return service, nil } - func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { @@ -5408,17 +5407,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { return nil } -// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy -func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error { +// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist +func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + now := time.Now() + result := s.db.WithContext(ctx). Model(&proxy.Proxy{}). Where("id = ? AND status = ?", proxyID, "connected"). 
- Update("last_seen", time.Now()) + Update("last_seen", now) if result.Error != nil { log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error) return status.Errorf(status.Internal, "failed to update proxy heartbeat") } + + if result.RowsAffected == 0 { + p := &proxy.Proxy{ + ID: proxyID, + ClusterAddress: clusterAddress, + IPAddress: ipAddress, + LastSeen: now, + ConnectedAt: &now, + Status: "connected", + } + if err := s.db.WithContext(ctx).Save(p).Error; err != nil { + log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err) + return status.Errorf(status.Internal, "failed to create proxy on heartbeat") + } + } + return nil } @@ -5428,7 +5445,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string result := s.db.WithContext(ctx). Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5440,6 +5457,81 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string return addresses, nil } +// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count. +func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { + var clusters []proxy.Cluster + + result := s.db.Model(&proxy.Proxy{}). + Select("cluster_address as address, COUNT(*) as connected_proxies"). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Group("cluster_address"). 
+ Scan(&clusters) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error) + return nil, status.Errorf(status.Internal, "get active proxy clusters") + } + + return clusters, nil +} + +// proxyActiveThreshold is the maximum age of a heartbeat for a proxy to be +// considered active. Must be at least 2x the heartbeat interval (1 min). +const proxyActiveThreshold = 2 * time.Minute + +var validCapabilityColumns = map[string]struct{}{ + "supports_custom_ports": {}, + "require_subdomain": {}, +} + +// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "supports_custom_ports") +} + +// GetClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "require_subdomain") +} + +// getClusterCapability returns an aggregated boolean capability for the given +// cluster. It checks active (connected, recently seen) proxies and returns: +// - *true if any proxy in the cluster has the capability set to true, +// - *false if at least one proxy reported but none set it to true, +// - nil if no proxy reported the capability at all. +func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column string) *bool { + if _, ok := validCapabilityColumns[column]; !ok { + log.WithContext(ctx).Errorf("invalid capability column: %s", column) + return nil + } + + var result struct { + HasCapability bool + AnyTrue bool + } + + err := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). 
+ Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+ + "COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true"). + Where("cluster_address = ? AND status = ? AND last_seen > ?", + clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). + Scan(&result).Error + + if err != nil { + log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) + return nil + } + + if !result.HasCapability { + return nil + } + + return &result.AnyTrue +} + // CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { cutoffTime := time.Now().Add(-inactivityDuration) @@ -5459,3 +5551,61 @@ func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration t return nil } + +// GetRoutingPeerNetworks returns the distinct network names where the peer is assigned as a routing peer +// in an enabled network router, either directly or via peer groups. +func (s *SqlStore) GetRoutingPeerNetworks(_ context.Context, accountID, peerID string) ([]string, error) { + var routers []*routerTypes.NetworkRouter + if err := s.db.Select("peer, peer_groups, network_id").Where("account_id = ? AND enabled = true", accountID).Find(&routers).Error; err != nil { + return nil, status.Errorf(status.Internal, "failed to get enabled routers: %v", err) + } + + if len(routers) == 0 { + return nil, nil + } + + var groupPeers []types.GroupPeer + if err := s.db.Select("group_id").Where("account_id = ? 
AND peer_id = ?", accountID, peerID).Find(&groupPeers).Error; err != nil { + return nil, status.Errorf(status.Internal, "failed to get peer group memberships: %v", err) + } + + groupSet := make(map[string]struct{}, len(groupPeers)) + for _, gp := range groupPeers { + groupSet[gp.GroupID] = struct{}{} + } + + networkIDs := make(map[string]struct{}) + for _, r := range routers { + if r.Peer == peerID { + networkIDs[r.NetworkID] = struct{}{} + } else if r.Peer == "" { + for _, pg := range r.PeerGroups { + if _, ok := groupSet[pg]; ok { + networkIDs[r.NetworkID] = struct{}{} + break + } + } + } + } + + if len(networkIDs) == 0 { + return nil, nil + } + + ids := make([]string, 0, len(networkIDs)) + for id := range networkIDs { + ids = append(ids, id) + } + + var networks []*networkTypes.Network + if err := s.db.Select("name").Where("account_id = ? AND id IN ?", accountID, ids).Find(&networks).Error; err != nil { + return nil, status.Errorf(status.Internal, "failed to get networks: %v", err) + } + + names := make([]string, 0, len(networks)) + for _, n := range networks { + names = append(names, n.Name) + } + + return names, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 8bb52f38a..e24a1efef 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -284,11 +284,16 @@ type Store interface { DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID string) error + UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx 
context.Context, inactivityDuration time.Duration) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) + + GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) } const ( diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index e75e35b94..a8648aed7 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -165,6 +165,34 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) } +// GetClusterSupportsCustomPorts mocks base method. +func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts. +func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr) +} + +// GetClusterRequireSubdomain mocks base method. +func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain. 
+func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) +} + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -1287,6 +1315,21 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) } +// GetActiveProxyClusters mocks base method. +func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx) + ret0, _ := ret[0].([]proxy.Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. +func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx) +} + // GetAllAccounts mocks base method. func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { m.ctrl.T.Helper() @@ -2318,6 +2361,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) } +// GetRoutingPeerNetworks mocks base method. 
+func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks. +func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID) +} + // IsPrimaryAccount mocks base method. func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { m.ctrl.T.Helper() @@ -2924,17 +2982,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{} } // UpdateProxyHeartbeat mocks base method. -func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error { +func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID) + ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress) ret0, _ := ret[0].(error) return ret0 } // UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. 
-func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress) } // UpdateService mocks base method. diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index a2252cc20..1c36ee334 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -62,6 +62,7 @@ var ( proxyProtocol bool preSharedKey string supportsCustomPorts bool + requireSubdomain bool geoDataDir string ) @@ -101,6 +102,7 @@ func init() { rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") + rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain") rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)") rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", 
envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)") rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)") @@ -181,6 +183,7 @@ func runServer(cmd *cobra.Command, args []string) error { ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, SupportsCustomPorts: supportsCustomPorts, + RequireSubdomain: requireSubdomain, MaxDialTimeout: maxDialTimeout, MaxSessionIdleTimeout: maxSessionIdleTimeout, GeoDataDir: geoDataDir, diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index a4924d380..6063f070e 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -932,3 +932,71 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) { assert.Equal(t, "header-user", capturedData2.GetUserID()) assert.Equal(t, "header", capturedData2.GetAuthMethod()) } + +// TestProtect_HeaderAuth_MultipleValuesSameHeader verifies that the proxy +// correctly handles multiple valid credentials for the same header name. +// In production, the mgmt gRPC authenticateHeader iterates all configured +// header auths and accepts if any hash matches (OR semantics). The proxy +// creates one Header scheme per entry, but a single gRPC call checks all. +func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + // Mock simulates mgmt behavior: accepts either token-a or token-b. 
+ accepted := map[string]bool{"Bearer token-a": true, "Bearer token-b": true} + mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + ha := req.GetHeaderAuth() + if ha != nil && accepted[ha.GetHeaderValue()] { + token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour) + require.NoError(t, err) + return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil + } + return &proto.AuthenticateResponse{Success: false}, nil + }} + + // Single Header scheme (as if one entry existed), but the mock checks both values. + hdr := NewHeader(mock, "svc1", "acc1", "Authorization") + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + + var backendCalled bool + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + })) + + t.Run("first value accepted", func(t *testing.T) { + backendCalled = false + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer token-a") + req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData(""))) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.True(t, backendCalled, "first token should be accepted") + }) + + t.Run("second value accepted", func(t *testing.T) { + backendCalled = false + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer token-b") + req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData(""))) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.True(t, backendCalled, "second token should be accepted") + }) + + t.Run("unknown value rejected", func(t *testing.T) 
{ + backendCalled = false + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.Header.Set("Authorization", "Bearer token-c") + req = req.WithContext(proxy.WithCapturedData(req.Context(), proxy.NewCapturedData(""))) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + assert.False(t, backendCalled, "unknown token should be rejected") + }) +} diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index 237010922..c507cfad9 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -409,17 +409,13 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc } pbStatus := nbstatus.ToProtoFullStatus(fullStatus) - overview := nbstatus.ConvertToStatusOutputOverview( - pbStatus, - false, - version.NetbirdVersion(), - statusFilter, - prefixNamesFilter, - prefixNamesFilterMap, - ipsFilterMap, - connectionTypeFilter, - "", - ) + overview := nbstatus.ConvertToStatusOutputOverview(pbStatus, nbstatus.ConvertOptions{ + StatusFilter: statusFilter, + PrefixNamesFilter: prefixNamesFilter, + PrefixNamesFilterMap: prefixNamesFilterMap, + IPsFilter: ipsFilterMap, + ConnectionTypeFilter: connectionTypeFilter, + }) if wantJSON { h.writeJSON(w, map[string]interface{}{ diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 8af151446..796cad622 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -200,7 +200,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. 
type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { return nil } @@ -208,7 +208,7 @@ func (m *testProxyManager) Disconnect(_ context.Context, _ string) error { return nil } -func (m *testProxyManager) Heartbeat(_ context.Context, _ string) error { +func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { return nil } @@ -216,6 +216,18 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin return nil, nil } +func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) { + return nil, nil +} + +func (m *testProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } @@ -243,10 +255,6 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string { return nil } -func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { - return nil -} - // storeBackedServiceManager reads directly from the real store. 
type storeBackedServiceManager struct { store store.Store @@ -323,6 +331,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} +func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { + return nil, nil +} + func strPtr(s string) *string { return &s } diff --git a/proxy/server.go b/proxy/server.go index c4d12859b..acfe3c12d 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -163,6 +163,10 @@ type Server struct { // SupportsCustomPorts indicates whether the proxy can bind arbitrary // ports for TCP/UDP/TLS services. SupportsCustomPorts bool + // RequireSubdomain indicates whether a subdomain label is required + // in front of this proxy's cluster domain. When true, accounts cannot + // create services on the bare cluster domain. + RequireSubdomain bool // MaxDialTimeout caps the per-service backend dial timeout. // When the API sends a timeout, it is clamped to this value. // When the API sends no timeout, this value is used as the default. 
@@ -919,6 +923,7 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr Address: s.ProxyURL, Capabilities: &proto.ProxyCapabilities{ SupportsCustomPorts: &s.SupportsCustomPorts, + RequireSubdomain: &s.RequireSubdomain, }, }) if err != nil { diff --git a/release_files/install.sh b/release_files/install.sh index 6a2c5f458..1e71936f3 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -128,7 +128,7 @@ cat <<-EOF | ${SUDO} tee /etc/yum.repos.d/netbird.repo name=NetBird baseurl=https://pkgs.netbird.io/yum/ enabled=1 -gpgcheck=0 +gpgcheck=1 gpgkey=https://pkgs.netbird.io/yum/repodata/repomd.xml.key repo_gpgcheck=1 EOF diff --git a/shared/management/client/client.go b/shared/management/client/client.go index ba525602e..a15301223 100644 --- a/shared/management/client/client.go +++ b/shared/management/client/client.go @@ -22,6 +22,7 @@ type Client interface { GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) + GetServerURL() string IsHealthy() bool SyncMeta(sysInfo *system.Info) error Logout() error diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 333f0bf00..252199498 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "io" + "os" + "strconv" "sync" "time" @@ -29,6 +31,10 @@ import ( const ConnectTimeout = 10 * time.Second const ( + // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) + // for the management client connection. Value is in bytes. 
+ EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE" + errMsgMgmtPublicKey = "failed getting Management Service public key: %s" errMsgNoMgmtConnection = "no connection to management" ) @@ -46,6 +52,7 @@ type GrpcClient struct { conn *grpc.ClientConn connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex + serverURL string } type ExposeRequest struct { @@ -66,13 +73,41 @@ type ExposeResponse struct { PortAutoAssigned bool } +// MaxRecvMsgSize returns the configured max gRPC receive message size from +// the environment, or 0 if unset (which uses the gRPC default of 4 MB). +func MaxRecvMsgSize() int { + val := os.Getenv(EnvMaxRecvMsgSize) + if val == "" { + return 0 + } + + size, err := strconv.Atoi(val) + if err != nil { + log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err) + return 0 + } + + if size <= 0 { + log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size) + return 0 + } + + return size +} + // NewClient creates a new client to Management service func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { var conn *grpc.ClientConn + var extraOpts []grpc.DialOption + if maxSize := MaxRecvMsgSize(); maxSize > 0 { + extraOpts = append(extraOpts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxSize))) + log.Infof("management gRPC max receive message size set to %d bytes", maxSize) + } + operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent, extraOpts...) 
if err != nil { return fmt.Errorf("create connection: %w", err) } @@ -93,9 +128,15 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE ctx: ctx, conn: conn, connStateCallbackLock: sync.RWMutex{}, + serverURL: addr, }, nil } +// GetServerURL returns the management server URL +func (c *GrpcClient) GetServerURL() string { + return c.serverURL +} + // Close closes connection to the Management Service func (c *GrpcClient) Close() error { return c.conn.Close() diff --git a/shared/management/client/grpc_test.go b/shared/management/client/grpc_test.go new file mode 100644 index 000000000..462cc43af --- /dev/null +++ b/shared/management/client/grpc_test.go @@ -0,0 +1,95 @@ +package client + +import ( + "context" + "net" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestMaxRecvMsgSize(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {name: "unset returns 0", envValue: "", expected: 0}, + {name: "valid value", envValue: "10485760", expected: 10485760}, + {name: "non-numeric returns 0", envValue: "abc", expected: 0}, + {name: "negative returns 0", envValue: "-1", expected: 0}, + {name: "zero returns 0", envValue: "0", expected: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv(EnvMaxRecvMsgSize, tt.envValue) + if tt.envValue == "" { + os.Unsetenv(EnvMaxRecvMsgSize) + } + assert.Equal(t, tt.expected, MaxRecvMsgSize()) + }) + } +} + +// largeSyncServer implements just the Sync RPC, returning a response larger than the default 4MB limit. 
+type largeSyncServer struct { + mgmtProto.UnimplementedManagementServiceServer + responseSize int +} + +func (s *largeSyncServer) GetServerKey(_ context.Context, _ *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) { + // Return a response with a large WiretrusteeConfig to exceed the default limit. + padding := strings.Repeat("x", s.responseSize) + return &mgmtProto.ServerKeyResponse{ + Key: padding, + }, nil +} + +func TestMaxRecvMsgSizeIntegration(t *testing.T) { + const payloadSize = 5 * 1024 * 1024 // 5MB, exceeds 4MB default + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + srv := grpc.NewServer() + mgmtProto.RegisterManagementServiceServer(srv, &largeSyncServer{responseSize: payloadSize}) + go func() { _ = srv.Serve(lis) }() + t.Cleanup(srv.Stop) + + t.Run("default limit rejects large message", func(t *testing.T) { + conn, err := grpc.NewClient( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer conn.Close() + + client := mgmtProto.NewManagementServiceClient(conn) + _, err = client.GetServerKey(context.Background(), &mgmtProto.Empty{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "received message larger than max") + }) + + t.Run("increased limit accepts large message", func(t *testing.T) { + conn, err := grpc.NewClient( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(10*1024*1024)), + ) + require.NoError(t, err) + defer conn.Close() + + client := mgmtProto.NewManagementServiceClient(conn) + resp, err := client.GetServerKey(context.Background(), &mgmtProto.Empty{}) + require.NoError(t, err) + assert.Len(t, resp.Key, payloadSize) + }) +} diff --git a/shared/management/client/mock.go b/shared/management/client/mock.go index 57256d6d4..548e379e8 100644 --- a/shared/management/client/mock.go +++ b/shared/management/client/mock.go @@ -19,6 +19,7 @@ type 
MockClient struct { LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) + GetServerURLFunc func() string SyncMetaFunc func(sysInfo *system.Info) error LogoutFunc func() error JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error @@ -92,6 +93,14 @@ func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) { return nil, nil } +// GetServerURL mock implementation of GetServerURL from mgm.Client interface +func (m *MockClient) GetServerURL() string { + if m.GetServerURLFunc == nil { + return "" + } + return m.GetServerURLFunc() +} + func (m *MockClient) SyncMeta(sysInfo *system.Info) error { if m.SyncMetaFunc == nil { return nil diff --git a/shared/management/client/rest/azure_idp.go b/shared/management/client/rest/azure_idp.go new file mode 100644 index 000000000..40b90bc30 --- /dev/null +++ b/shared/management/client/rest/azure_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// AzureIDPAPI APIs for Azure AD IDP integrations +type AzureIDPAPI struct { + c *Client +} + +// List retrieves all Azure AD IDP integrations +func (a *AzureIDPAPI) List(ctx context.Context) ([]api.AzureIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.AzureIntegration](resp) + return ret, err +} + +// Get retrieves a specific Azure AD IDP integration by ID +func (a *AzureIDPAPI) Get(ctx context.Context, integrationID string) (*api.AzureIntegration, error) { + resp, err := 
a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Create creates a new Azure AD IDP integration +func (a *AzureIDPAPI) Create(ctx context.Context, request api.CreateAzureIntegrationRequest) (*api.AzureIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Update updates an existing Azure AD IDP integration +func (a *AzureIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateAzureIntegrationRequest) (*api.AzureIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/azure-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.AzureIntegration](resp) + return &ret, err +} + +// Delete deletes an Azure AD IDP integration +func (a *AzureIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/azure-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// Sync triggers a manual sync for an Azure AD IDP integration +func (a *AzureIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp/"+integrationID+"/sync", nil, nil) + 
if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.SyncResult](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for an Azure AD IDP integration +func (a *AzureIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/azure_idp_test.go b/shared/management/client/rest/azure_idp_test.go new file mode 100644 index 000000000..480d2a313 --- /dev/null +++ b/shared/management/client/rest/azure_idp_test.go @@ -0,0 +1,252 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testAzureIntegration = api.AzureIntegration{ + Id: 1, + Enabled: true, + ClientId: "12345678-1234-1234-1234-123456789012", + TenantId: "87654321-4321-4321-4321-210987654321", + SyncInterval: 300, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + Host: "microsoft.com", + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestAzureIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.AzureIntegration{testAzureIntegration}) + _, err := w.Write(retBytes) + 
require.NoError(t, err) + }) + ret, err := c.AzureIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testAzureIntegration, ret[0]) + }) +} + +func TestAzureIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestAzureIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + 
require.NoError(t, err) + var req api.CreateAzureIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "12345678-1234-1234-1234-123456789012", req.ClientId) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{ + ClientId: "12345678-1234-1234-1234-123456789012", + ClientSecret: "secret", + TenantId: "87654321-4321-4321-4321-210987654321", + Host: api.CreateAzureIntegrationRequestHostMicrosoftCom, + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateAzureIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testAzureIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, 
testAzureIntegration, *ret) + }) +} + +func TestAzureIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.AzureIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestAzureIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.AzureIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestAzureIDP_Sync_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Sync(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, "ok", *ret.Result) + }) +} + +func 
TestAzureIDP_Sync_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.Sync(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestAzureIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestAzureIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/client/rest/client.go b/shared/management/client/rest/client.go index f308761fb..f0cb4d2d1 100644 --- a/shared/management/client/rest/client.go +++ b/shared/management/client/rest/client.go @@ -110,6 +110,15 @@ type Client struct { // see more: https://docs.netbird.io/api/resources/scim SCIM *SCIMAPI + // GoogleIDP NetBird Google Workspace IDP 
integration APIs + GoogleIDP *GoogleIDPAPI + + // AzureIDP NetBird Azure AD IDP integration APIs + AzureIDP *AzureIDPAPI + + // OktaScimIDP NetBird Okta SCIM IDP integration APIs + OktaScimIDP *OktaScimIDPAPI + // EventStreaming NetBird Event Streaming integration APIs // see more: https://docs.netbird.io/api/resources/event-streaming EventStreaming *EventStreamingAPI @@ -185,6 +194,9 @@ func (c *Client) initialize() { c.MSP = &MSPAPI{c} c.EDR = &EDRAPI{c} c.SCIM = &SCIMAPI{c} + c.GoogleIDP = &GoogleIDPAPI{c} + c.AzureIDP = &AzureIDPAPI{c} + c.OktaScimIDP = &OktaScimIDPAPI{c} c.EventStreaming = &EventStreamingAPI{c} c.IdentityProviders = &IdentityProvidersAPI{c} c.Ingress = &IngressAPI{c} diff --git a/shared/management/client/rest/google_idp.go b/shared/management/client/rest/google_idp.go new file mode 100644 index 000000000..b86436503 --- /dev/null +++ b/shared/management/client/rest/google_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// GoogleIDPAPI APIs for Google Workspace IDP integrations +type GoogleIDPAPI struct { + c *Client +} + +// List retrieves all Google Workspace IDP integrations +func (a *GoogleIDPAPI) List(ctx context.Context) ([]api.GoogleIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.GoogleIntegration](resp) + return ret, err +} + +// Get retrieves a specific Google Workspace IDP integration by ID +func (a *GoogleIDPAPI) Get(ctx context.Context, integrationID string) (*api.GoogleIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + 
return &ret, err +} + +// Create creates a new Google Workspace IDP integration +func (a *GoogleIDPAPI) Create(ctx context.Context, request api.CreateGoogleIntegrationRequest) (*api.GoogleIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Update updates an existing Google Workspace IDP integration +func (a *GoogleIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateGoogleIntegrationRequest) (*api.GoogleIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/google-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.GoogleIntegration](resp) + return &ret, err +} + +// Delete deletes a Google Workspace IDP integration +func (a *GoogleIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/google-idp/"+integrationID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// Sync triggers a manual sync for a Google Workspace IDP integration +func (a *GoogleIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp/"+integrationID+"/sync", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.SyncResult](resp) + return &ret, err +} + +// GetLogs retrieves 
synchronization logs for a Google Workspace IDP integration +func (a *GoogleIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/google_idp_test.go b/shared/management/client/rest/google_idp_test.go new file mode 100644 index 000000000..03a6c161e --- /dev/null +++ b/shared/management/client/rest/google_idp_test.go @@ -0,0 +1,248 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testGoogleIntegration = api.GoogleIntegration{ + Id: 1, + Enabled: true, + CustomerId: "C01234567", + SyncInterval: 300, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestGoogleIDP_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.GoogleIntegration{testGoogleIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testGoogleIntegration, ret[0]) + }) +} + +func TestGoogleIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux 
*http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestGoogleIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateGoogleIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "C01234567", req.CustomerId) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err = w.Write(retBytes) + 
require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{ + CustomerId: "C01234567", + ServiceAccountKey: "key-data", + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateGoogleIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := json.Marshal(testGoogleIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testGoogleIntegration, *ret) + }) +} + +func TestGoogleIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + 
ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.GoogleIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestGoogleIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.GoogleIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestGoogleIDP_Sync_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.Sync(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, "ok", *ret.Result) + }) +} + +func TestGoogleIDP_Sync_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := 
c.GoogleIDP.Sync(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestGoogleIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestGoogleIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git a/shared/management/client/rest/okta_scim_idp.go b/shared/management/client/rest/okta_scim_idp.go new file mode 100644 index 000000000..eb677dae8 --- /dev/null +++ b/shared/management/client/rest/okta_scim_idp.go @@ -0,0 +1,112 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// OktaScimIDPAPI APIs for Okta SCIM IDP integrations +type OktaScimIDPAPI struct { + c *Client +} + +// List retrieves all Okta SCIM IDP integrations +func (a *OktaScimIDPAPI) List(ctx context.Context) ([]api.OktaScimIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != 
nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.OktaScimIntegration](resp) + return ret, err +} + +// Get retrieves a specific Okta SCIM IDP integration by ID +func (a *OktaScimIDPAPI) Get(ctx context.Context, integrationID string) (*api.OktaScimIntegration, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Create creates a new Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Create(ctx context.Context, request api.CreateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Update updates an existing Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/okta-scim-idp/"+integrationID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.OktaScimIntegration](resp) + return &ret, err +} + +// Delete deletes an Okta SCIM IDP integration +func (a *OktaScimIDPAPI) Delete(ctx context.Context, integrationID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil) + if err != nil { + 
return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} + +// RegenerateToken regenerates the SCIM API token for an Okta SCIM integration +func (a *OktaScimIDPAPI) RegenerateToken(ctx context.Context, integrationID string) (*api.ScimTokenResponse, error) { + resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp/"+integrationID+"/token", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.ScimTokenResponse](resp) + return &ret, err +} + +// GetLogs retrieves synchronization logs for an Okta SCIM IDP integration +func (a *OktaScimIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID+"/logs", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp) + return ret, err +} diff --git a/shared/management/client/rest/okta_scim_idp_test.go b/shared/management/client/rest/okta_scim_idp_test.go new file mode 100644 index 000000000..d8d1f2b51 --- /dev/null +++ b/shared/management/client/rest/okta_scim_idp_test.go @@ -0,0 +1,246 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var testOktaScimIntegration = api.OktaScimIntegration{ + Id: 1, + AuthToken: "****", + Enabled: true, + GroupPrefixes: []string{"eng-"}, + UserGroupPrefixes: []string{"dev-"}, + LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), +} + +func TestOktaScimIDP_List_200(t *testing.T) { 
+ withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.OktaScimIntegration{testOktaScimIntegration}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.List(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testOktaScimIntegration, ret[0]) + }) +} + +func TestOktaScimIDP_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.List(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestOktaScimIDP_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Get(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Get(context.Background(), "int-1") + assert.Error(t, err) + 
assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.CreateOktaScimIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "my-okta-connection", req.ConnectionName) + retBytes, _ := json.Marshal(testOktaScimIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{ + ConnectionName: "my-okta-connection", + GroupPrefixes: &[]string{"eng-"}, + }) + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.UpdateOktaScimIntegrationRequest + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, true, *req.Enabled) + retBytes, _ := 
json.Marshal(testOktaScimIntegration) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{ + Enabled: ptr(true), + }) + require.NoError(t, err) + assert.Equal(t, testOktaScimIntegration, *ret) + }) +} + +func TestOktaScimIDP_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{}) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.OktaScimIDP.Delete(context.Background(), "int-1") + require.NoError(t, err) + }) +} + +func TestOktaScimIDP_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.OktaScimIDP.Delete(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestOktaScimIDP_RegenerateToken_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r 
*http.Request) { + assert.Equal(t, "POST", r.Method) + retBytes, _ := json.Marshal(testScimToken) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1") + require.NoError(t, err) + assert.Equal(t, testScimToken, *ret) + }) +} + +func TestOktaScimIDP_RegenerateToken_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestOktaScimIDP_GetLogs_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testSyncLog, ret[0]) + }) +} + +func TestOktaScimIDP_GetLogs_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} diff --git 
a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 66f39b92f..833468676 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -68,8 +68,17 @@ tags: - name: MSP description: MSP portal for Tenant management. x-cloud-only: true - - name: IDP - description: Manage identity provider integrations for user and group sync. + - name: IDP SCIM Integrations + description: Manage generic SCIM identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Google Integrations + description: Manage Google Workspace identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Azure Integrations + description: Manage Azure AD identity provider integrations for user and group sync. + x-cloud-only: true + - name: IDP Okta SCIM Integrations + description: Manage Okta SCIM identity provider integrations for user and group sync. x-cloud-only: true - name: EDR Intune Integrations description: Manage Microsoft Intune EDR integrations. @@ -89,6 +98,10 @@ tags: - name: Event Streaming Integrations description: Manage event streaming integrations. x-cloud-only: true + - name: Notifications + description: Manage notification channels for account event alerts. + x-cloud-only: true + components: schemas: @@ -2995,6 +3008,11 @@ components: type: boolean description: Whether the service is enabled example: true + terminated: + type: boolean + description: Whether the service has been terminated. Terminated services cannot be updated. Services that violate the Terms of Service will be terminated. 
+ readOnly: true + example: false pass_host_header: type: boolean description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address @@ -3336,6 +3354,10 @@ components: type: boolean description: Whether the cluster supports binding arbitrary TCP/UDP ports example: true + require_subdomain: + type: boolean + description: Whether a subdomain label is required in front of this domain. When true, the domain cannot be used bare. + example: false required: - id - domain @@ -4254,96 +4276,89 @@ components: description: Status of agent firewall. Can be one of Disabled, Enabled, Pending Isolation, Isolated, Pending Release. example: "Enabled" + IntegrationSyncFilters: + type: object + properties: + group_prefixes: + type: array + description: List of start_with string patterns for groups to sync + items: + type: string + example: [ "Engineering", "Sales" ] + user_group_prefixes: + type: array + description: List of start_with string patterns for groups which users to sync + items: + type: string + example: [ "Users" ] + IntegrationEnabled: + type: object + properties: + enabled: + type: boolean + description: Whether the integration is enabled + example: true CreateScimIntegrationRequest: - type: object - description: Request payload for creating an SCIM IDP integration - required: - - prefix - - provider - properties: - prefix: - type: string - description: The connection prefix used for the SCIM provider - provider: - type: string - description: Name of the SCIM identity provider - group_prefixes: - type: array - description: List of start_with string patterns for groups to sync - items: - type: string - example: [ "Engineering", "Sales" ] - user_group_prefixes: - type: array - description: List of start_with string patterns for groups which users to sync - items: - type: string - example: [ "Users" ] + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: 
Request payload for creating an SCIM IDP integration + required: + - prefix + - provider + properties: + prefix: + type: string + description: The connection prefix used for the SCIM provider + provider: + type: string + description: Name of the SCIM identity provider UpdateScimIntegrationRequest: - type: object - description: Request payload for updating an SCIM IDP integration - properties: - enabled: - type: boolean - description: Indicates whether the integration is enabled - example: true - group_prefixes: - type: array - description: List of start_with string patterns for groups to sync - items: - type: string - example: [ "Engineering", "Sales" ] - user_group_prefixes: - type: array - description: List of start_with string patterns for groups which users to sync - items: - type: string - example: [ "Users" ] + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an SCIM IDP integration + properties: + prefix: + type: string + description: The connection prefix used for the SCIM provider ScimIntegration: - type: object - description: Represents a SCIM IDP integration - required: - - id - - enabled - - provider - - group_prefixes - - user_group_prefixes - - auth_token - - last_synced_at - properties: - id: - type: integer - format: int64 - description: The unique identifier for the integration - example: 123 - enabled: - type: boolean - description: Indicates whether the integration is enabled - example: true - provider: - type: string - description: Name of the SCIM identity provider - group_prefixes: - type: array - description: List of start_with string patterns for groups to sync - items: - type: string - example: [ "Engineering", "Sales" ] - user_group_prefixes: - type: array - description: List of start_with string patterns for groups which users to sync - items: - type: string - example: [ "Users" ] - auth_token: - type: string - 
description: SCIM API token (full on creation, masked otherwise) - example: "nbs_abc***********************************" - last_synced_at: - type: string - format: date-time - description: Timestamp of when the integration was last synced - example: "2023-05-15T10:30:00Z" + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents a SCIM IDP integration + required: + - id + - enabled + - prefix + - provider + - group_prefixes + - user_group_prefixes + - auth_token + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 123 + prefix: + type: string + description: The connection prefix used for the SCIM provider + provider: + type: string + description: Name of the SCIM identity provider + auth_token: + type: string + description: SCIM API token (full on creation, masked otherwise) + example: "nbs_abc***********************************" + last_synced_at: + type: string + format: date-time + description: Timestamp of when the integration was last synced + example: "2023-05-15T10:30:00Z" IdpIntegrationSyncLog: type: object description: Represents a synchronization log entry for an integration @@ -4381,6 +4396,346 @@ components: type: string description: The newly generated SCIM API token example: "nbs_F3f0d..." + CreateGoogleIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating a Google Workspace IDP integration + required: + - service_account_key + - customer_id + properties: + service_account_key: + type: string + description: Base64-encoded Google service account key + example: "eyJ0eXBlIjoic2VydmljZV9hY2NvdW50Ii..." 
+ customer_id: + type: string + description: Customer ID from Google Workspace Account Settings + example: "C01234567" + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + minimum: 300 + example: 300 + UpdateGoogleIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating a Google Workspace IDP integration. All fields are optional. + properties: + service_account_key: + type: string + description: Base64-encoded Google service account key + customer_id: + type: string + description: Customer ID from Google Workspace Account Settings + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300) + minimum: 300 + GoogleIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents a Google Workspace IDP integration + required: + - id + - customer_id + - sync_interval + - enabled + - group_prefixes + - user_group_prefixes + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + customer_id: + type: string + description: Customer ID from Google Workspace + example: "C01234567" + sync_interval: + type: integer + description: Sync interval in seconds + example: 300 + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + CreateAzureIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an Azure AD IDP integration + required: + - client_secret + - client_id + - tenant_id + - host + properties: + client_secret: + type: string + description: Base64-encoded Azure AD client secret 
+ example: "c2VjcmV0..." + client_id: + type: string + description: Azure AD application (client) ID + example: "12345678-1234-1234-1234-123456789012" + tenant_id: + type: string + description: Azure AD tenant ID + example: "87654321-4321-4321-4321-210987654321" + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + minimum: 300 + example: 300 + host: + type: string + description: Azure host domain for the Graph API + enum: + - microsoft.com + - microsoft.us + example: "microsoft.com" + UpdateAzureIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an Azure AD IDP integration. All fields are optional. + properties: + client_secret: + type: string + description: Base64-encoded Azure AD client secret + client_id: + type: string + description: Azure AD application (client) ID + tenant_id: + type: string + description: Azure AD tenant ID + sync_interval: + type: integer + description: Sync interval in seconds (minimum 300) + minimum: 300 + AzureIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents an Azure AD IDP integration + required: + - id + - client_id + - tenant_id + - sync_interval + - enabled + - group_prefixes + - user_group_prefixes + - host + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + client_id: + type: string + description: Azure AD application (client) ID + example: "12345678-1234-1234-1234-123456789012" + tenant_id: + type: string + description: Azure AD tenant ID + example: "87654321-4321-4321-4321-210987654321" + sync_interval: + type: integer + description: Sync interval in seconds + example: 300 + host: + type: string + 
description: Azure host domain for the Graph API + example: "microsoft.com" + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + CreateOktaScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for creating an Okta SCIM IDP integration + required: + - connection_name + properties: + connection_name: + type: string + description: The Okta enterprise connection name on Auth0 + example: "my-okta-connection" + UpdateOktaScimIntegrationRequest: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Request payload for updating an Okta SCIM IDP integration. All fields are optional. + OktaScimIntegration: + allOf: + - $ref: '#/components/schemas/IntegrationEnabled' + - $ref: '#/components/schemas/IntegrationSyncFilters' + - type: object + description: Represents an Okta SCIM IDP integration + required: + - id + - enabled + - group_prefixes + - user_group_prefixes + - auth_token + - last_synced_at + properties: + id: + type: integer + format: int64 + description: The unique identifier for the integration + example: 1 + auth_token: + type: string + description: SCIM API token (full on creation/regeneration, masked on retrieval) + example: "nbs_abc***********************************" + last_synced_at: + type: string + format: date-time + description: Timestamp of the last synchronization + example: "2023-05-15T10:30:00Z" + SyncResult: + type: object + description: Response for a manual sync trigger + properties: + result: + type: string + example: "ok" + NotificationChannelType: + type: string + description: The type of notification channel. + enum: + - email + - webhook + example: "email" + NotificationEventType: + type: string + description: | + An activity event type code. 
See `GET /api/integrations/notifications/types` for the full list + of supported event types and their human-readable descriptions. + example: "user.join" + EmailTarget: + type: object + description: Target configuration for email notification channels. + properties: + emails: + type: array + description: List of email addresses to send notifications to. + minItems: 1 + items: + type: string + format: email + example: [ "admin@example.com", "ops@example.com" ] + required: + - emails + WebhookTarget: + type: object + description: Target configuration for webhook notification channels. + properties: + url: + type: string + format: uri + description: The webhook endpoint URL to send notifications to. + example: "https://hooks.example.com/netbird" + headers: + type: object + additionalProperties: + type: string + description: | + Custom HTTP headers sent with each webhook request. + Values are write-only; in GET responses all values are masked. + example: + Authorization: "Bearer token" + X-Webhook-Secret: "secret" + required: + - url + NotificationChannelRequest: + type: object + description: Request body for creating or updating a notification channel. + properties: + type: + $ref: '#/components/schemas/NotificationChannelType' + target: + description: | + Channel-specific target configuration. The shape depends on the `type` field: + - `email`: requires an `EmailTarget` object + - `webhook`: requires a `WebhookTarget` object + oneOf: + - $ref: '#/components/schemas/EmailTarget' + - $ref: '#/components/schemas/WebhookTarget' + event_types: + type: array + description: List of activity event type codes this channel subscribes to. + items: + $ref: '#/components/schemas/NotificationEventType' + example: [ "user.join", "peer.user.add", "peer.login.expire" ] + enabled: + type: boolean + description: Whether this notification channel is active. 
+ example: true + required: + - type + - event_types + - enabled + NotificationChannelResponse: + type: object + description: A notification channel configuration. + properties: + id: + type: string + description: Unique identifier of the notification channel. + readOnly: true + example: "ch8i4ug6lnn4g9hqv7m0" + type: + $ref: '#/components/schemas/NotificationChannelType' + target: + description: | + Channel-specific target configuration. The shape depends on the `type` field: + - `email`: an `EmailTarget` object + - `webhook`: a `WebhookTarget` object + oneOf: + - $ref: '#/components/schemas/EmailTarget' + - $ref: '#/components/schemas/WebhookTarget' + event_types: + type: array + description: List of activity event type codes this channel subscribes to. + items: + $ref: '#/components/schemas/NotificationEventType' + example: [ "user.join", "peer.user.add", "peer.login.expire" ] + enabled: + type: boolean + description: Whether this notification channel is active. + example: true + required: + - id + - type + - event_types + - enabled + NotificationTypeEntry: + type: object + description: A map of event type codes to their human-readable descriptions. + additionalProperties: + type: string + example: + user.join: "User joined" BypassResponse: type: object description: Response for bypassed peer operations. @@ -9017,10 +9372,877 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp: + post: + tags: + - IDP Google Integrations + summary: Create Google IDP Integration + description: Creates a new Google Workspace IDP integration + operationId: createGoogleIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateGoogleIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Google Integrations + summary: Get All Google IDP Integrations + description: Retrieves all Google Workspace IDP integrations for the authenticated account + operationId: getAllGoogleIntegrations + responses: + '200': + description: A list of Google IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/GoogleIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Google Integrations + summary: Get Google IDP Integration + description: Retrieves a Google IDP integration by ID. + operationId: getGoogleIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Google Integrations + summary: Update Google IDP Integration + description: Updates an existing Google Workspace IDP integration. + operationId: updateGoogleIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateGoogleIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/GoogleIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Google Integrations + summary: Delete Google IDP Integration + description: Deletes a Google IDP integration by ID. + operationId: deleteGoogleIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. 
+ content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}/sync: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Google Integrations + summary: Sync Google IDP Integration + description: Triggers a manual synchronization for a Google IDP integration. + operationId: syncGoogleIntegration + responses: + '200': + description: Sync triggered successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/SyncResult' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/google-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Google IDP integration. 
+ schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Google Integrations + summary: Get Google Integration Sync Logs + description: Retrieves synchronization logs for a Google IDP integration. + operationId: getGoogleIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp: + post: + tags: + - IDP Azure Integrations + summary: Create Azure IDP Integration + description: Creates a new Azure AD IDP integration + operationId: createAzureIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateAzureIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Azure Integrations + summary: Get All Azure IDP Integrations + description: Retrieves all Azure AD IDP integrations for the authenticated account + operationId: getAllAzureIntegrations + responses: + '200': + description: A list of Azure IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AzureIntegration' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Azure Integrations + summary: Get Azure IDP Integration + description: Retrieves an Azure IDP integration by ID. + operationId: getAzureIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Azure Integrations + summary: Update Azure IDP Integration + description: Updates an existing Azure AD IDP integration. + operationId: updateAzureIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateAzureIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. + content: + application/json: + schema: + $ref: '#/components/schemas/AzureIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Azure Integrations + summary: Delete Azure IDP Integration + description: Deletes an Azure IDP integration by ID. + operationId: deleteAzureIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}/sync: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Azure Integrations + summary: Sync Azure IDP Integration + description: Triggers a manual synchronization for an Azure IDP integration. + operationId: syncAzureIntegration + responses: + '200': + description: Sync triggered successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/SyncResult' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/azure-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Azure IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Azure Integrations + summary: Get Azure Integration Sync Logs + description: Retrieves synchronization logs for an Azure IDP integration. + operationId: getAzureIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp: + post: + tags: + - IDP Okta SCIM Integrations + summary: Create Okta SCIM IDP Integration + description: Creates a new Okta SCIM IDP integration + operationId: createOktaScimIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateOktaScimIntegrationRequest' + responses: + '200': + description: Integration created successfully. Returns the created integration. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid JSON, missing required fields, validation error). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized (e.g., missing or invalid authentication token). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + get: + tags: + - IDP Okta SCIM Integrations + summary: Get All Okta SCIM IDP Integrations + description: Retrieves all Okta SCIM IDP integrations for the authenticated account + operationId: getAllOktaScimIntegrations + responses: + '200': + description: A list of Okta SCIM IDP integrations. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/OktaScimIntegration' + '401': + description: Unauthorized. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Okta SCIM Integrations + summary: Get Okta SCIM IDP Integration + description: Retrieves an Okta SCIM IDP integration by ID. + operationId: getOktaScimIntegration + responses: + '200': + description: Successfully retrieved the integration details. + content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found (e.g., integration with the given ID does not exist). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + put: + tags: + - IDP Okta SCIM Integrations + summary: Update Okta SCIM IDP Integration + description: Updates an existing Okta SCIM IDP integration. + operationId: updateOktaScimIntegration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateOktaScimIntegrationRequest' + responses: + '200': + description: Integration updated successfully. Returns the updated integration. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/OktaScimIntegration' + '400': + description: Bad Request (e.g., invalid JSON, validation error, invalid ID). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + delete: + tags: + - IDP Okta SCIM Integrations + summary: Delete Okta SCIM IDP Integration + description: Deletes an Okta SCIM IDP integration by ID. + operationId: deleteOktaScimIntegration + responses: + '200': + description: Integration deleted successfully. Returns an empty object. + content: + application/json: + schema: + type: object + example: { } + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}/token: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + post: + tags: + - IDP Okta SCIM Integrations + summary: Regenerate Okta SCIM Token + description: Regenerates the SCIM API token for an Okta SCIM IDP integration. 
+ operationId: regenerateOktaScimToken + responses: + '200': + description: Token regenerated successfully. Returns the new token. + content: + application/json: + schema: + $ref: '#/components/schemas/ScimTokenResponse' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /api/integrations/okta-scim-idp/{id}/logs: + parameters: + - name: id + in: path + required: true + description: The unique identifier of the Okta SCIM IDP integration. + schema: + type: integer + format: int64 + example: 1 + get: + tags: + - IDP Okta SCIM Integrations + summary: Get Okta SCIM Integration Sync Logs + description: Retrieves synchronization logs for an Okta SCIM IDP integration. + operationId: getOktaScimIntegrationLogs + responses: + '200': + description: Successfully retrieved the integration sync logs. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IdpIntegrationSyncLog' + '400': + description: Bad Request (e.g., invalid integration ID format). + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not Found. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal Server Error. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/integrations/scim-idp: post: tags: - - IDP + - IDP SCIM Integrations summary: Create SCIM IDP Integration description: Creates a new SCIM integration operationId: createSCIMIntegration @@ -9057,7 +10279,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' get: tags: - - IDP + - IDP SCIM Integrations summary: Get All SCIM IDP Integrations description: Retrieves all SCIM IDP integrations for the authenticated account operationId: getAllSCIMIntegrations @@ -9089,11 +10311,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 get: tags: - - IDP + - IDP SCIM Integrations summary: Get SCIM IDP Integration description: Retrieves an SCIM IDP integration by ID. operationId: getSCIMIntegration @@ -9130,7 +10353,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' put: tags: - - IDP + - IDP SCIM Integrations summary: Update SCIM IDP Integration description: Updates an existing SCIM IDP Integration. operationId: updateSCIMIntegration @@ -9173,7 +10396,7 @@ paths: $ref: '#/components/schemas/ErrorResponse' delete: tags: - - IDP + - IDP SCIM Integrations summary: Delete SCIM IDP Integration description: Deletes an SCIM IDP integration by ID. operationId: deleteSCIMIntegration @@ -9216,11 +10439,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 post: tags: - - IDP + - IDP SCIM Integrations summary: Regenerate SCIM Token description: Regenerates the SCIM API token for an SCIM IDP integration. operationId: regenerateSCIMToken @@ -9262,11 +10486,12 @@ paths: required: true description: The unique identifier of the SCIM IDP integration. 
schema: - type: string - example: "ch8i4ug6lnn4g9hqv7m0" + type: integer + format: int64 + example: 1 get: tags: - - IDP + - IDP SCIM Integrations summary: Get SCIM Integration Sync Logs description: Retrieves synchronization logs for a SCIM IDP integration. operationId: getSCIMIntegrationLogs @@ -10058,3 +11283,172 @@ paths: "$ref": "#/components/responses/not_found" '500': "$ref": "#/components/responses/internal_error" + /api/integrations/notifications/types: + get: + tags: + - Notifications + summary: List Notification Event Types + description: | + Returns a map of all supported activity event type codes to their + human-readable descriptions. Use these codes when configuring + `event_types` on notification channels. + operationId: listNotificationEventTypes + responses: + '200': + description: A map of event type codes to descriptions. + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationTypeEntry' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/integrations/notifications/channels: + get: + tags: + - Notifications + summary: List Notification Channels + description: Retrieves all notification channels configured for the authenticated account. + operationId: listNotificationChannels + responses: + '200': + description: A list of notification channels. 
+ content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + post: + tags: + - Notifications + summary: Create Notification Channel + description: | + Creates a new notification channel for the authenticated account. + Supported channel types are `email` and `webhook`. + operationId: createNotificationChannel + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelRequest' + responses: + '200': + description: Notification channel created successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/integrations/notifications/channels/{channelId}: + parameters: + - name: channelId + in: path + required: true + description: The unique identifier of the notification channel. + schema: + type: string + example: "ch8i4ug6lnn4g9hqv7m0" + get: + tags: + - Notifications + summary: Get Notification Channel + description: Retrieves a specific notification channel by its ID. + operationId: getNotificationChannel + responses: + '200': + description: Successfully retrieved the notification channel. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + tags: + - Notifications + summary: Update Notification Channel + description: Updates an existing notification channel. + operationId: updateNotificationChannel + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelRequest' + responses: + '200': + description: Notification channel updated successfully. + content: + application/json: + schema: + $ref: '#/components/schemas/NotificationChannelResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + delete: + tags: + - Notifications + summary: Delete Notification Channel + description: Deletes a notification channel by its ID. + operationId: deleteNotificationChannel + responses: + '200': + description: Notification channel deleted successfully. 
+ content: + application/json: + schema: + type: object + example: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 693449d14..fb9976c89 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -9,6 +9,7 @@ import ( "time" "github.com/oapi-codegen/runtime" + openapi_types "github.com/oapi-codegen/runtime/types" ) const ( @@ -16,6 +17,24 @@ const ( TokenAuthScopes = "TokenAuth.Scopes" ) +// Defines values for CreateAzureIntegrationRequestHost. +const ( + CreateAzureIntegrationRequestHostMicrosoftCom CreateAzureIntegrationRequestHost = "microsoft.com" + CreateAzureIntegrationRequestHostMicrosoftUs CreateAzureIntegrationRequestHost = "microsoft.us" +) + +// Valid indicates whether the value is a known member of the CreateAzureIntegrationRequestHost enum. +func (e CreateAzureIntegrationRequestHost) Valid() bool { + switch e { + case CreateAzureIntegrationRequestHostMicrosoftCom: + return true + case CreateAzureIntegrationRequestHostMicrosoftUs: + return true + default: + return false + } +} + // Defines values for CreateIntegrationRequestPlatform. const ( CreateIntegrationRequestPlatformDatadog CreateIntegrationRequestPlatform = "datadog" @@ -664,6 +683,24 @@ func (e NetworkResourceType) Valid() bool { } } +// Defines values for NotificationChannelType. +const ( + NotificationChannelTypeEmail NotificationChannelType = "email" + NotificationChannelTypeWebhook NotificationChannelType = "webhook" +) + +// Valid indicates whether the value is a known member of the NotificationChannelType enum. 
+func (e NotificationChannelType) Valid() bool { + switch e { + case NotificationChannelTypeEmail: + return true + case NotificationChannelTypeWebhook: + return true + default: + return false + } +} + // Defines values for PeerNetworkRangeCheckAction. const ( PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow" @@ -1450,6 +1487,36 @@ type AvailablePorts struct { Udp int `json:"udp"` } +// AzureIntegration defines model for AzureIntegration. +type AzureIntegration struct { + // ClientId Azure AD application (client) ID + ClientId string `json:"client_id"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Host Azure host domain for the Graph API + Host string `json:"host"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // SyncInterval Sync interval in seconds + SyncInterval int `json:"sync_interval"` + + // TenantId Azure AD tenant ID + TenantId string `json:"tenant_id"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // BearerAuthConfig defines model for BearerAuthConfig. type BearerAuthConfig struct { // DistributionGroups List of group IDs that can use bearer auth @@ -1557,6 +1624,51 @@ type Country struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country type CountryCode = string +// CreateAzureIntegrationRequest defines model for CreateAzureIntegrationRequest. 
+type CreateAzureIntegrationRequest struct { + // ClientId Azure AD application (client) ID + ClientId string `json:"client_id"` + + // ClientSecret Base64-encoded Azure AD client secret + ClientSecret string `json:"client_secret"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // Host Azure host domain for the Graph API + Host CreateAzureIntegrationRequestHost `json:"host"` + + // SyncInterval Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + SyncInterval *int `json:"sync_interval,omitempty"` + + // TenantId Azure AD tenant ID + TenantId string `json:"tenant_id"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// CreateAzureIntegrationRequestHost Azure host domain for the Graph API +type CreateAzureIntegrationRequestHost string + +// CreateGoogleIntegrationRequest defines model for CreateGoogleIntegrationRequest. +type CreateGoogleIntegrationRequest struct { + // CustomerId Customer ID from Google Workspace Account Settings + CustomerId string `json:"customer_id"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // ServiceAccountKey Base64-encoded Google service account key + ServiceAccountKey string `json:"service_account_key"` + + // SyncInterval Sync interval in seconds (minimum 300). Defaults to 300 if not specified. + SyncInterval *int `json:"sync_interval,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // CreateIntegrationRequest Request payload for creating a new event streaming integration. 
Also used as the structure for the PUT request body, but not all fields are applicable for updates (see PUT operation description). type CreateIntegrationRequest struct { // Config Platform-specific configuration as key-value pairs. For creation, all necessary credentials and settings must be provided. For updates, provide the fields to change or the entire new configuration. @@ -1572,7 +1684,19 @@ type CreateIntegrationRequest struct { // CreateIntegrationRequestPlatform The event streaming platform to integrate with (e.g., "datadog", "s3", "firehose"). This field is used for creation. For updates (PUT), this field, if sent, is ignored by the backend. type CreateIntegrationRequestPlatform string -// CreateScimIntegrationRequest Request payload for creating an SCIM IDP integration +// CreateOktaScimIntegrationRequest defines model for CreateOktaScimIntegrationRequest. +type CreateOktaScimIntegrationRequest struct { + // ConnectionName The Okta enterprise connection name on Auth0 + ConnectionName string `json:"connection_name"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// CreateScimIntegrationRequest defines model for CreateScimIntegrationRequest. type CreateScimIntegrationRequest struct { // GroupPrefixes List of start_with string patterns for groups to sync GroupPrefixes *[]string `json:"group_prefixes,omitempty"` @@ -1893,6 +2017,12 @@ type EDRSentinelOneResponse struct { UpdatedAt time.Time `json:"updated_at"` } +// EmailTarget Target configuration for email notification channels. +type EmailTarget struct { + // Emails List of email addresses to send notifications to. + Emails []openapi_types.Email `json:"emails"` +} + // ErrorResponse Standard error response. 
Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. type ErrorResponse struct { // Message A human-readable error message. @@ -1947,6 +2077,30 @@ type GeoLocationCheckAction string // GetTenantsResponse defines model for GetTenantsResponse. type GetTenantsResponse = []TenantResponse +// GoogleIntegration defines model for GoogleIntegration. +type GoogleIntegration struct { + // CustomerId Customer ID from Google Workspace + CustomerId string `json:"customer_id"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // SyncInterval Sync interval in seconds + SyncInterval int `json:"sync_interval"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // Group defines model for Group. type Group struct { // Id Group ID @@ -2238,6 +2392,12 @@ type InstanceVersionInfo struct { ManagementUpdateAvailable bool `json:"management_update_available"` } +// IntegrationEnabled defines model for IntegrationEnabled. +type IntegrationEnabled struct { + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` +} + // IntegrationResponse Represents an event streaming integration. type IntegrationResponse struct { // AccountId The identifier of the account this integration belongs to. @@ -2265,6 +2425,15 @@ type IntegrationResponse struct { // IntegrationResponsePlatform The event streaming platform. 
type IntegrationResponsePlatform string +// IntegrationSyncFilters defines model for IntegrationSyncFilters. +type IntegrationSyncFilters struct { + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + // InvoicePDFResponse defines model for InvoicePDFResponse. type InvoicePDFResponse struct { // Url URL to redirect the user to invoice. @@ -2666,6 +2835,67 @@ type NetworkTrafficUser struct { Name string `json:"name"` } +// NotificationChannelRequest Request body for creating or updating a notification channel. +type NotificationChannelRequest struct { + // Enabled Whether this notification channel is active. + Enabled bool `json:"enabled"` + + // EventTypes List of activity event type codes this channel subscribes to. + EventTypes []NotificationEventType `json:"event_types"` + + // Target Channel-specific target configuration. The shape depends on the `type` field: + // - `email`: requires an `EmailTarget` object + // - `webhook`: requires a `WebhookTarget` object + Target *NotificationChannelRequest_Target `json:"target,omitempty"` + + // Type The type of notification channel. + Type NotificationChannelType `json:"type"` +} + +// NotificationChannelRequest_Target Channel-specific target configuration. The shape depends on the `type` field: +// - `email`: requires an `EmailTarget` object +// - `webhook`: requires a `WebhookTarget` object +type NotificationChannelRequest_Target struct { + union json.RawMessage +} + +// NotificationChannelResponse A notification channel configuration. +type NotificationChannelResponse struct { + // Enabled Whether this notification channel is active. + Enabled bool `json:"enabled"` + + // EventTypes List of activity event type codes this channel subscribes to. 
+ EventTypes []NotificationEventType `json:"event_types"` + + // Id Unique identifier of the notification channel. + Id *string `json:"id,omitempty"` + + // Target Channel-specific target configuration. The shape depends on the `type` field: + // - `email`: an `EmailTarget` object + // - `webhook`: a `WebhookTarget` object + Target *NotificationChannelResponse_Target `json:"target,omitempty"` + + // Type The type of notification channel. + Type NotificationChannelType `json:"type"` +} + +// NotificationChannelResponse_Target Channel-specific target configuration. The shape depends on the `type` field: +// - `email`: an `EmailTarget` object +// - `webhook`: a `WebhookTarget` object +type NotificationChannelResponse_Target struct { + union json.RawMessage +} + +// NotificationChannelType The type of notification channel. +type NotificationChannelType string + +// NotificationEventType An activity event type code. See `GET /api/integrations/notifications/types` for the full list +// of supported event types and their human-readable descriptions. +type NotificationEventType = string + +// NotificationTypeEntry A map of event type codes to their human-readable descriptions. +type NotificationTypeEntry map[string]string + // OSVersionCheck Posture check for the version of operating system type OSVersionCheck struct { // Android Posture check for the version of operating system @@ -2684,6 +2914,27 @@ type OSVersionCheck struct { Windows *MinKernelVersionCheck `json:"windows,omitempty"` } +// OktaScimIntegration defines model for OktaScimIntegration. 
+type OktaScimIntegration struct { + // AuthToken SCIM API token (full on creation/regeneration, masked on retrieval) + AuthToken string `json:"auth_token"` + + // Enabled Whether the integration is enabled + Enabled bool `json:"enabled"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes []string `json:"group_prefixes"` + + // Id The unique identifier for the integration + Id int64 `json:"id"` + + // LastSyncedAt Timestamp of the last synchronization + LastSyncedAt time.Time `json:"last_synced_at"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes []string `json:"user_group_prefixes"` +} + // PINAuthConfig defines model for PINAuthConfig. type PINAuthConfig struct { // Enabled Whether PIN auth is enabled @@ -3406,6 +3657,9 @@ type ReverseProxyDomain struct { // Id Domain ID Id string `json:"id"` + // RequireSubdomain Whether a subdomain label is required in front of this domain. When true, the domain cannot be used bare. + RequireSubdomain *bool `json:"require_subdomain,omitempty"` + // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` @@ -3530,12 +3784,12 @@ type RulePortRange struct { Start int `json:"start"` } -// ScimIntegration Represents a SCIM IDP integration +// ScimIntegration defines model for ScimIntegration. 
type ScimIntegration struct { // AuthToken SCIM API token (full on creation, masked otherwise) AuthToken string `json:"auth_token"` - // Enabled Indicates whether the integration is enabled + // Enabled Whether the integration is enabled Enabled bool `json:"enabled"` // GroupPrefixes List of start_with string patterns for groups to sync @@ -3547,6 +3801,9 @@ type ScimIntegration struct { // LastSyncedAt Timestamp of when the integration was last synced LastSyncedAt time.Time `json:"last_synced_at"` + // Prefix The connection prefix used for the SCIM provider + Prefix string `json:"prefix"` + // Provider Name of the SCIM identity provider Provider string `json:"provider"` @@ -3629,6 +3886,9 @@ type Service struct { // Targets List of target backends for this service Targets []ServiceTarget `json:"targets"` + + // Terminated Whether the service has been terminated. Terminated services cannot be updated. Services that violate the Terms of Service will be terminated. + Terminated *bool `json:"terminated,omitempty"` } // ServiceMode Service mode. "http" for L7 reverse proxy, "tcp"/"udp"/"tls" for L4 passthrough. @@ -3948,6 +4208,11 @@ type Subscription struct { UpdatedAt time.Time `json:"updated_at"` } +// SyncResult Response for a manual sync trigger +type SyncResult struct { + Result *string `json:"result,omitempty"` +} + // TenantGroupResponse defines model for TenantGroupResponse. type TenantGroupResponse struct { // Id The Group ID @@ -3993,14 +4258,74 @@ type TenantResponse struct { // TenantResponseStatus The status of the tenant type TenantResponseStatus string -// UpdateScimIntegrationRequest Request payload for updating an SCIM IDP integration -type UpdateScimIntegrationRequest struct { - // Enabled Indicates whether the integration is enabled +// UpdateAzureIntegrationRequest defines model for UpdateAzureIntegrationRequest. 
+type UpdateAzureIntegrationRequest struct { + // ClientId Azure AD application (client) ID + ClientId *string `json:"client_id,omitempty"` + + // ClientSecret Base64-encoded Azure AD client secret + ClientSecret *string `json:"client_secret,omitempty"` + + // Enabled Whether the integration is enabled Enabled *bool `json:"enabled,omitempty"` // GroupPrefixes List of start_with string patterns for groups to sync GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + // SyncInterval Sync interval in seconds (minimum 300) + SyncInterval *int `json:"sync_interval,omitempty"` + + // TenantId Azure AD tenant ID + TenantId *string `json:"tenant_id,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateGoogleIntegrationRequest defines model for UpdateGoogleIntegrationRequest. +type UpdateGoogleIntegrationRequest struct { + // CustomerId Customer ID from Google Workspace Account Settings + CustomerId *string `json:"customer_id,omitempty"` + + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // ServiceAccountKey Base64-encoded Google service account key + ServiceAccountKey *string `json:"service_account_key,omitempty"` + + // SyncInterval Sync interval in seconds (minimum 300) + SyncInterval *int `json:"sync_interval,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateOktaScimIntegrationRequest defines model for UpdateOktaScimIntegrationRequest. 
+type UpdateOktaScimIntegrationRequest struct { + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // UserGroupPrefixes List of start_with string patterns for groups which users to sync + UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` +} + +// UpdateScimIntegrationRequest defines model for UpdateScimIntegrationRequest. +type UpdateScimIntegrationRequest struct { + // Enabled Whether the integration is enabled + Enabled *bool `json:"enabled,omitempty"` + + // GroupPrefixes List of start_with string patterns for groups to sync + GroupPrefixes *[]string `json:"group_prefixes,omitempty"` + + // Prefix The connection prefix used for the SCIM provider + Prefix *string `json:"prefix,omitempty"` + // UserGroupPrefixes List of start_with string patterns for groups which users to sync UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"` } @@ -4208,6 +4533,16 @@ type UserRequest struct { Role string `json:"role"` } +// WebhookTarget Target configuration for webhook notification channels. +type WebhookTarget struct { + // Headers Custom HTTP headers sent with each webhook request. + // Values are write-only; in GET responses all values are masked. + Headers *map[string]string `json:"headers,omitempty"` + + // Url The webhook endpoint URL to send notifications to. + Url string `json:"url"` +} + // WorkloadRequest defines model for WorkloadRequest. type WorkloadRequest struct { union json.RawMessage @@ -4510,6 +4845,12 @@ type PostApiIngressPeersJSONRequestBody = IngressPeerCreateRequest // PutApiIngressPeersIngressPeerIdJSONRequestBody defines body for PutApiIngressPeersIngressPeerId for application/json ContentType. 
type PutApiIngressPeersIngressPeerIdJSONRequestBody = IngressPeerUpdateRequest +// CreateAzureIntegrationJSONRequestBody defines body for CreateAzureIntegration for application/json ContentType. +type CreateAzureIntegrationJSONRequestBody = CreateAzureIntegrationRequest + +// UpdateAzureIntegrationJSONRequestBody defines body for UpdateAzureIntegration for application/json ContentType. +type UpdateAzureIntegrationJSONRequestBody = UpdateAzureIntegrationRequest + // PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody defines body for PostApiIntegrationsBillingAwsMarketplaceActivate for application/json ContentType. type PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody @@ -4546,6 +4887,12 @@ type CreateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest // UpdateSentinelOneEDRIntegrationJSONRequestBody defines body for UpdateSentinelOneEDRIntegration for application/json ContentType. type UpdateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest +// CreateGoogleIntegrationJSONRequestBody defines body for CreateGoogleIntegration for application/json ContentType. +type CreateGoogleIntegrationJSONRequestBody = CreateGoogleIntegrationRequest + +// UpdateGoogleIntegrationJSONRequestBody defines body for UpdateGoogleIntegration for application/json ContentType. +type UpdateGoogleIntegrationJSONRequestBody = UpdateGoogleIntegrationRequest + // PostApiIntegrationsMspTenantsJSONRequestBody defines body for PostApiIntegrationsMspTenants for application/json ContentType. type PostApiIntegrationsMspTenantsJSONRequestBody = CreateTenantRequest @@ -4561,6 +4908,18 @@ type PostApiIntegrationsMspTenantsIdSubscriptionJSONRequestBody PostApiIntegrati // PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody defines body for PostApiIntegrationsMspTenantsIdUnlink for application/json ContentType. 
type PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody PostApiIntegrationsMspTenantsIdUnlinkJSONBody +// CreateNotificationChannelJSONRequestBody defines body for CreateNotificationChannel for application/json ContentType. +type CreateNotificationChannelJSONRequestBody = NotificationChannelRequest + +// UpdateNotificationChannelJSONRequestBody defines body for UpdateNotificationChannel for application/json ContentType. +type UpdateNotificationChannelJSONRequestBody = NotificationChannelRequest + +// CreateOktaScimIntegrationJSONRequestBody defines body for CreateOktaScimIntegration for application/json ContentType. +type CreateOktaScimIntegrationJSONRequestBody = CreateOktaScimIntegrationRequest + +// UpdateOktaScimIntegrationJSONRequestBody defines body for UpdateOktaScimIntegration for application/json ContentType. +type UpdateOktaScimIntegrationJSONRequestBody = UpdateOktaScimIntegrationRequest + // CreateSCIMIntegrationJSONRequestBody defines body for CreateSCIMIntegration for application/json ContentType. type CreateSCIMIntegrationJSONRequestBody = CreateScimIntegrationRequest @@ -4657,6 +5016,130 @@ type PutApiUsersUserIdPasswordJSONRequestBody = PasswordChangeRequest // PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType. 
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest +// AsEmailTarget returns the union data inside the NotificationChannelRequest_Target as a EmailTarget +func (t NotificationChannelRequest_Target) AsEmailTarget() (EmailTarget, error) { + var body EmailTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromEmailTarget overwrites any union data inside the NotificationChannelRequest_Target as the provided EmailTarget +func (t *NotificationChannelRequest_Target) FromEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeEmailTarget performs a merge with any union data inside the NotificationChannelRequest_Target, using the provided EmailTarget +func (t *NotificationChannelRequest_Target) MergeEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsWebhookTarget returns the union data inside the NotificationChannelRequest_Target as a WebhookTarget +func (t NotificationChannelRequest_Target) AsWebhookTarget() (WebhookTarget, error) { + var body WebhookTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWebhookTarget overwrites any union data inside the NotificationChannelRequest_Target as the provided WebhookTarget +func (t *NotificationChannelRequest_Target) FromWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWebhookTarget performs a merge with any union data inside the NotificationChannelRequest_Target, using the provided WebhookTarget +func (t *NotificationChannelRequest_Target) MergeWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t NotificationChannelRequest_Target) MarshalJSON() ([]byte, error) { + b, err := 
t.union.MarshalJSON() + return b, err +} + +func (t *NotificationChannelRequest_Target) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + +// AsEmailTarget returns the union data inside the NotificationChannelResponse_Target as a EmailTarget +func (t NotificationChannelResponse_Target) AsEmailTarget() (EmailTarget, error) { + var body EmailTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromEmailTarget overwrites any union data inside the NotificationChannelResponse_Target as the provided EmailTarget +func (t *NotificationChannelResponse_Target) FromEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeEmailTarget performs a merge with any union data inside the NotificationChannelResponse_Target, using the provided EmailTarget +func (t *NotificationChannelResponse_Target) MergeEmailTarget(v EmailTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsWebhookTarget returns the union data inside the NotificationChannelResponse_Target as a WebhookTarget +func (t NotificationChannelResponse_Target) AsWebhookTarget() (WebhookTarget, error) { + var body WebhookTarget + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromWebhookTarget overwrites any union data inside the NotificationChannelResponse_Target as the provided WebhookTarget +func (t *NotificationChannelResponse_Target) FromWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeWebhookTarget performs a merge with any union data inside the NotificationChannelResponse_Target, using the provided WebhookTarget +func (t *NotificationChannelResponse_Target) MergeWebhookTarget(v WebhookTarget) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = 
merged + return err +} + +func (t NotificationChannelResponse_Target) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *NotificationChannelResponse_Target) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + // AsBundleWorkloadRequest returns the union data inside the WorkloadRequest as a BundleWorkloadRequest func (t WorkloadRequest) AsBundleWorkloadRequest() (BundleWorkloadRequest, error) { var body BundleWorkloadRequest diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index e5a2d6a98..93295e857 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -181,8 +181,11 @@ type ProxyCapabilities struct { state protoimpl.MessageState `protogen:"open.v1"` // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. SupportsCustomPorts *bool `protobuf:"varint,1,opt,name=supports_custom_ports,json=supportsCustomPorts,proto3,oneof" json:"supports_custom_ports,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Whether the proxy requires a subdomain label in front of its cluster domain. + // When true, tenants cannot use the cluster domain bare. + RequireSubdomain *bool `protobuf:"varint,2,opt,name=require_subdomain,json=requireSubdomain,proto3,oneof" json:"require_subdomain,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ProxyCapabilities) Reset() { @@ -222,6 +225,13 @@ func (x *ProxyCapabilities) GetSupportsCustomPorts() bool { return false } +func (x *ProxyCapabilities) GetRequireSubdomain() bool { + if x != nil && x.RequireSubdomain != nil { + return *x.RequireSubdomain + } + return false +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. 
type GetMappingUpdateRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1872,10 +1882,12 @@ var File_proxy_service_proto protoreflect.FileDescriptor const file_proxy_service_proto_rawDesc = "" + "\n" + "\x13proxy_service.proto\x12\n" + - "management\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"f\n" + + "management\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xae\x01\n" + "\x11ProxyCapabilities\x127\n" + - "\x15supports_custom_ports\x18\x01 \x01(\bH\x00R\x13supportsCustomPorts\x88\x01\x01B\x18\n" + - "\x16_supports_custom_ports\"\xe6\x01\n" + + "\x15supports_custom_ports\x18\x01 \x01(\bH\x00R\x13supportsCustomPorts\x88\x01\x01\x120\n" + + "\x11require_subdomain\x18\x02 \x01(\bH\x01R\x10requireSubdomain\x88\x01\x01B\x18\n" + + "\x16_supports_custom_portsB\x14\n" + + "\x12_require_subdomain\"\xe6\x01\n" + "\x17GetMappingUpdateRequest\x12\x19\n" + "\bproxy_id\x18\x01 \x01(\tR\aproxyId\x12\x18\n" + "\aversion\x18\x02 \x01(\tR\aversion\x129\n" + diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index 2d7bed548..f77071eb0 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -31,6 +31,9 @@ service ProxyService { message ProxyCapabilities { // Whether the proxy can bind arbitrary ports for TCP/UDP/TLS services. optional bool supports_custom_ports = 1; + // Whether the proxy requires a subdomain label in front of its cluster domain. + // When true, accounts cannot use the cluster domain bare. + optional bool require_subdomain = 2; } // GetMappingUpdateRequest is sent to initialise a mapping stream. 
diff --git a/upload-server/server/s3_test.go b/upload-server/server/s3_test.go index 26b0ecd09..7ab1bb379 100644 --- a/upload-server/server/s3_test.go +++ b/upload-server/server/s3_test.go @@ -5,13 +5,12 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "os" "runtime" "testing" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -20,45 +19,55 @@ import ( ) func Test_S3HandlerGetUploadURL(t *testing.T) { - if runtime.GOOS != "linux" && os.Getenv("CI") == "true" { - t.Skip("Skipping test on non-Linux and CI environment due to docker dependency") - } - if runtime.GOOS == "windows" { - t.Skip("Skipping test on Windows due to potential docker dependency") + if runtime.GOOS != "linux" { + t.Skip("Skipping test on non-Linux due to docker dependency") } - awsEndpoint := "http://127.0.0.1:4566" awsRegion := "us-east-1" ctx := context.Background() - containerRequest := testcontainers.ContainerRequest{ - Image: "localstack/localstack:s3-latest", - ExposedPorts: []string{"4566:4566/tcp"}, - WaitingFor: wait.ForLog("Ready"), - } - c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ - ContainerRequest: containerRequest, - Started: true, + ContainerRequest: testcontainers.ContainerRequest{ + Image: "minio/minio:RELEASE.2025-04-22T22-12-26Z", + ExposedPorts: []string{"9000/tcp"}, + Env: map[string]string{ + "MINIO_ROOT_USER": "minioadmin", + "MINIO_ROOT_PASSWORD": "minioadmin", + }, + Cmd: []string{"server", "/data"}, + WaitingFor: wait.ForHTTP("/minio/health/ready").WithPort("9000"), + }, + Started: true, }) - if err != nil { - t.Error(err) - } - defer func(c testcontainers.Container, ctx context.Context) { + require.NoError(t, err) + t.Cleanup(func() { if err := c.Terminate(ctx); err != 
nil { t.Log(err) } - }(c, ctx) + }) + + mappedPort, err := c.MappedPort(ctx, "9000") + require.NoError(t, err) + + hostIP, err := c.Host(ctx) + require.NoError(t, err) + + awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port() t.Setenv("AWS_REGION", awsRegion) t.Setenv("AWS_ENDPOINT_URL", awsEndpoint) - t.Setenv("AWS_ACCESS_KEY_ID", "test") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test") + t.Setenv("AWS_ACCESS_KEY_ID", "minioadmin") + t.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin") + t.Setenv("AWS_CONFIG_FILE", "") + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "") + t.Setenv("AWS_PROFILE", "") - cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint)) - if err != nil { - t.Error(err) - } + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(awsRegion), + config.WithBaseEndpoint(awsEndpoint), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("minioadmin", "minioadmin", "")), + ) + require.NoError(t, err) client := s3.NewFromConfig(cfg, func(o *s3.Options) { o.UsePathStyle = true @@ -66,19 +75,16 @@ func Test_S3HandlerGetUploadURL(t *testing.T) { }) bucketName := "test" - if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{ + _, err = client.CreateBucket(ctx, &s3.CreateBucketInput{ Bucket: &bucketName, - }); err != nil { - t.Error(err) - } + }) + require.NoError(t, err) list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{}) - if err != nil { - t.Error(err) - } + require.NoError(t, err) - assert.Equal(t, len(list.Buckets), 1) - assert.Equal(t, *list.Buckets[0].Name, bucketName) + require.Len(t, list.Buckets, 1) + require.Equal(t, bucketName, *list.Buckets[0].Name) t.Setenv(bucketVar, bucketName) diff --git a/util/log.go b/util/log.go index 03547024a..b1de2d999 100644 --- a/util/log.go +++ b/util/log.go @@ -43,7 +43,13 @@ func InitLogger(logger *log.Logger, logLevel string, logs ...string) error { var writers []io.Writer logFmt := os.Getenv("NB_LOG_FORMAT") + seen := 
make(map[string]bool, len(logs)) for _, logPath := range logs { + if seen[logPath] { + continue + } + seen[logPath] = true + switch logPath { case LogSyslog: AddSyslogHookToLogger(logger)