mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 07:33:52 -04:00
Compare commits
12 Commits
nb-interfa
...
log-checks
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
affa8bf348 | ||
|
|
c435c2727f | ||
|
|
643730f770 | ||
|
|
04fae00a6c | ||
|
|
1a9ea32c21 | ||
|
|
0ea5d020a3 | ||
|
|
459c9ef317 | ||
|
|
e5e275c87a | ||
|
|
d311f57559 | ||
|
|
1a28d18cde | ||
|
|
91e7423989 | ||
|
|
86c16cf651 |
20
.github/workflows/golang-test-linux.yml
vendored
20
.github/workflows/golang-test-linux.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -24,8 +24,8 @@ jobs:
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
management:
|
||||
- 'management/**'
|
||||
management:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
@@ -148,7 +148,7 @@ jobs:
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [build-cache]
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -181,6 +181,7 @@ jobs:
|
||||
env:
|
||||
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
|
||||
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
|
||||
CONTAINER: "true"
|
||||
run: |
|
||||
CONTAINER_GOCACHE="/root/.cache/go-build"
|
||||
CONTAINER_GOMODCACHE="/go/pkg/mod"
|
||||
@@ -198,6 +199,7 @@ jobs:
|
||||
-e GOARCH=${GOARCH_TARGET} \
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
-e CONTAINER=${CONTAINER} \
|
||||
golang:1.23-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
@@ -211,7 +213,11 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
include:
|
||||
- arch: "386"
|
||||
raceFlag: ""
|
||||
- arch: "amd64"
|
||||
raceFlag: ""
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -251,9 +257,9 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
-timeout 10m ./relay/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
|
||||
@@ -67,7 +67,6 @@ var (
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
serviceName string
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
@@ -116,15 +115,9 @@ func init() {
|
||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||
}
|
||||
|
||||
defaultServiceName := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
|
||||
@@ -135,7 +128,6 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
rootCmd.AddCommand(statusCmd)
|
||||
@@ -146,9 +138,6 @@ func init() {
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
|
||||
@@ -186,14 +175,13 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) {
|
||||
termCh := make(chan os.Signal, 1)
|
||||
signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
done := ctx.Done()
|
||||
defer cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
case <-termCh:
|
||||
}
|
||||
|
||||
log.Info("shutdown signal received")
|
||||
cancel()
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
@@ -14,6 +17,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
var serviceCmd = &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "manages Netbird service",
|
||||
}
|
||||
|
||||
var (
|
||||
serviceName string
|
||||
serviceEnvVars []string
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -22,12 +35,31 @@ type program struct {
|
||||
serverInstanceMu sync.Mutex
|
||||
}
|
||||
|
||||
func init() {
|
||||
defaultServiceName := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return &program{ctx: ctx, cancel: cancel}
|
||||
}
|
||||
|
||||
func newSVCConfig() *service.Config {
|
||||
func newSVCConfig() (*service.Config, error) {
|
||||
config := &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
@@ -36,23 +68,47 @@ func newSVCConfig() *service.Config {
|
||||
EnvVars: make(map[string]string),
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
extraEnvs, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse service environment variables: %w", err)
|
||||
}
|
||||
config.EnvVars = extraEnvs
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
||||
}
|
||||
|
||||
return config
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||
s, err := service.New(prg, conf)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return service.New(prg, conf)
|
||||
}
|
||||
|
||||
var serviceCmd = &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "manages Netbird service",
|
||||
func parseServiceEnvVars(envVars []string) (map[string]string, error) {
|
||||
envMap := make(map[string]string)
|
||||
|
||||
for _, env := range envVars {
|
||||
if env == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid environment variable format: %s (expected KEY=VALUE)", env)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("empty environment variable key in: %s", env)
|
||||
}
|
||||
|
||||
envMap[key] = value
|
||||
}
|
||||
|
||||
return envMap, nil
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
@@ -47,14 +49,13 @@ func (p *program) Start(svc service.Service) error {
|
||||
|
||||
listen, err := net.Listen(split[0], split[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen daemon interface: %w", err)
|
||||
return fmt.Errorf("listen daemon interface: %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer listen.Close()
|
||||
|
||||
if split[0] == "unix" {
|
||||
err = os.Chmod(split[1], 0666)
|
||||
if err != nil {
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
@@ -100,37 +101,49 @@ func (p *program) Stop(srv service.Service) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Common setup for service control commands
|
||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
if err := handleRebrand(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := util.InitLog(logLevel, logFile); err != nil {
|
||||
return nil, fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create service config: %w", err)
|
||||
}
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var runCmd = &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "runs Netbird as service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
return s.Run()
|
||||
},
|
||||
}
|
||||
|
||||
@@ -138,31 +151,14 @@ var startCmd = &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "starts Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
}
|
||||
err = s.Start()
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
|
||||
if err := s.Start(); err != nil {
|
||||
return fmt.Errorf("start service: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been started")
|
||||
return nil
|
||||
@@ -173,29 +169,14 @@ var stopCmd = &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "stops Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
return fmt.Errorf("stop service: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been stopped")
|
||||
return nil
|
||||
@@ -206,31 +187,48 @@ var restartCmd = &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "restarts Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Restart()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
if err := s.Restart(); err != nil {
|
||||
return fmt.Errorf("restart service: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been restarted")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var svcStatusCmd = &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "shows Netbird service status",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get service status: %w", err)
|
||||
}
|
||||
|
||||
var statusText string
|
||||
switch status {
|
||||
case service.StatusRunning:
|
||||
statusText = "Running"
|
||||
case service.StatusStopped:
|
||||
statusText = "Stopped"
|
||||
case service.StatusUnknown:
|
||||
statusText = "Unknown"
|
||||
default:
|
||||
statusText = fmt.Sprintf("Unknown (%d)", status)
|
||||
}
|
||||
|
||||
cmd.Printf("Netbird service status: %s\n", statusText)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,87 +1,121 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var ErrGetServiceStatus = fmt.Errorf("failed to get service status")
|
||||
|
||||
// Common service command setup
|
||||
func setupServiceCommand(cmd *cobra.Command) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
return handleRebrand(cmd)
|
||||
}
|
||||
|
||||
// Build service arguments for install/reconfigure
|
||||
func buildServiceArguments() []string {
|
||||
args := []string{
|
||||
"service",
|
||||
"run",
|
||||
"--config",
|
||||
configPath,
|
||||
"--log-level",
|
||||
logLevel,
|
||||
"--daemon-addr",
|
||||
daemonAddr,
|
||||
}
|
||||
|
||||
if managementURL != "" {
|
||||
args = append(args, "--management-url", managementURL)
|
||||
}
|
||||
|
||||
if logFile != "" {
|
||||
args = append(args, "--log-file", logFile)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// Configure platform-specific service settings
|
||||
func configurePlatformSpecificSettings(svcConfig *service.Config) error {
|
||||
if runtime.GOOS == "linux" {
|
||||
// Respected only by systemd systems
|
||||
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
|
||||
|
||||
if logFile != "console" {
|
||||
setStdLogPath := true
|
||||
dir := filepath.Dir(logFile)
|
||||
|
||||
if _, err := os.Stat(dir); err != nil {
|
||||
if err = os.MkdirAll(dir, 0750); err != nil {
|
||||
setStdLogPath = false
|
||||
}
|
||||
}
|
||||
|
||||
if setStdLogPath {
|
||||
svcConfig.Option["LogOutput"] = true
|
||||
svcConfig.Option["LogDirectory"] = dir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
svcConfig.Option["OnFailure"] = "restart"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create fully configured service config for install/reconfigure
|
||||
func createServiceConfigForInstall() (*service.Config, error) {
|
||||
svcConfig, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create service config: %w", err)
|
||||
}
|
||||
|
||||
svcConfig.Arguments = buildServiceArguments()
|
||||
if err = configurePlatformSpecificSettings(svcConfig); err != nil {
|
||||
return nil, fmt.Errorf("configure platform-specific settings: %w", err)
|
||||
}
|
||||
|
||||
return svcConfig, nil
|
||||
}
|
||||
|
||||
var installCmd = &cobra.Command{
|
||||
Use: "install",
|
||||
Short: "installs Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
if err := setupServiceCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svcConfig := newSVCConfig()
|
||||
|
||||
svcConfig.Arguments = []string{
|
||||
"service",
|
||||
"run",
|
||||
"--config",
|
||||
configPath,
|
||||
"--log-level",
|
||||
logLevel,
|
||||
"--daemon-addr",
|
||||
daemonAddr,
|
||||
}
|
||||
|
||||
if managementURL != "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||
}
|
||||
|
||||
if logFile != "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
// Respected only by systemd systems
|
||||
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
|
||||
|
||||
if logFile != "console" {
|
||||
setStdLogPath := true
|
||||
dir := filepath.Dir(logFile)
|
||||
|
||||
_, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(dir, 0750)
|
||||
if err != nil {
|
||||
setStdLogPath = false
|
||||
}
|
||||
}
|
||||
|
||||
if setStdLogPath {
|
||||
svcConfig.Option["LogOutput"] = true
|
||||
svcConfig.Option["LogDirectory"] = dir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
svcConfig.Option["OnFailure"] = "restart"
|
||||
svcConfig, err := createServiceConfigForInstall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.Install()
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
if err := s.Install(); err != nil {
|
||||
return fmt.Errorf("install service: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Netbird service has been installed")
|
||||
@@ -93,27 +127,109 @@ var uninstallCmd = &cobra.Command{
|
||||
Use: "uninstall",
|
||||
Short: "uninstalls Netbird service from system",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
if err := setupServiceCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create service config: %w", err)
|
||||
}
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.Uninstall(); err != nil {
|
||||
return fmt.Errorf("uninstall service: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Netbird service has been uninstalled")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var reconfigureCmd = &cobra.Command{
|
||||
Use: "reconfigure",
|
||||
Short: "reconfigures Netbird service with new settings",
|
||||
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install.
|
||||
This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := setupServiceCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wasRunning, err := isServiceRunning()
|
||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||
return fmt.Errorf("check service status: %w", err)
|
||||
}
|
||||
|
||||
svcConfig, err := createServiceConfigForInstall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
|
||||
err = s.Uninstall()
|
||||
if err != nil {
|
||||
return err
|
||||
if wasRunning {
|
||||
cmd.Println("Stopping Netbird service...")
|
||||
if err := s.Stop(); err != nil {
|
||||
cmd.Printf("Warning: failed to stop service: %v\n", err)
|
||||
}
|
||||
}
|
||||
cmd.Println("Netbird service has been uninstalled")
|
||||
|
||||
cmd.Println("Removing existing service configuration...")
|
||||
if err := s.Uninstall(); err != nil {
|
||||
return fmt.Errorf("uninstall existing service: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Installing service with new configuration...")
|
||||
if err := s.Install(); err != nil {
|
||||
return fmt.Errorf("install service with new config: %w", err)
|
||||
}
|
||||
|
||||
if wasRunning {
|
||||
cmd.Println("Starting Netbird service...")
|
||||
if err := s.Start(); err != nil {
|
||||
return fmt.Errorf("start service after reconfigure: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been reconfigured and started")
|
||||
} else {
|
||||
cmd.Println("Netbird service has been reconfigured")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func isServiceRunning() (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %w", ErrGetServiceStatus, err)
|
||||
}
|
||||
|
||||
return status == service.StatusRunning, nil
|
||||
}
|
||||
|
||||
263
client/cmd/service_test.go
Normal file
263
client/cmd/service_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestServiceEnvVars tests environment variable parsing
|
||||
func TestServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envVars []string
|
||||
expected map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid single env var",
|
||||
envVars: []string{"LOG_LEVEL=debug"},
|
||||
expected: map[string]string{
|
||||
"LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid multiple env vars",
|
||||
envVars: []string{"LOG_LEVEL=debug", "CUSTOM_VAR=value"},
|
||||
expected: map[string]string{
|
||||
"LOG_LEVEL": "debug",
|
||||
"CUSTOM_VAR": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Env var with spaces",
|
||||
envVars: []string{" KEY = value "},
|
||||
expected: map[string]string{
|
||||
"KEY": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid format - no equals",
|
||||
envVars: []string{"INVALID"},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid format - empty key",
|
||||
envVars: []string{"=value"},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Empty value is valid",
|
||||
envVars: []string{"KEY="},
|
||||
expected: map[string]string{
|
||||
"KEY": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
envVars: []string{},
|
||||
expected: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "Empty string in slice",
|
||||
envVars: []string{"", "KEY=value", ""},
|
||||
expected: map[string]string{"KEY": "value"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseServiceEnvVars(tt.envVars)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceConfigWithEnvVars tests service config creation with env vars
|
||||
func TestServiceConfigWithEnvVars(t *testing.T) {
|
||||
originalServiceName := serviceName
|
||||
originalServiceEnvVars := serviceEnvVars
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
serviceEnvVars = originalServiceEnvVars
|
||||
}()
|
||||
|
||||
serviceName = "test-service"
|
||||
serviceEnvVars = []string{"TEST_VAR=test_value", "ANOTHER_VAR=another_value"}
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test-service", cfg.Name)
|
||||
assert.Equal(t, "test_value", cfg.EnvVars["TEST_VAR"])
|
||||
assert.Equal(t, "another_value", cfg.EnvVars["ANOTHER_VAR"])
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
assert.Equal(t, "test-service", cfg.EnvVars["SYSTEMD_UNIT"])
|
||||
}
|
||||
}
|
||||
@@ -154,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapUDPConn(conn),
|
||||
UDPConn: nbnet.WrapPacketConn(conn),
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
|
||||
@@ -7,15 +7,16 @@ import (
|
||||
)
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
wrapped, ok := m.params.UDPConn.(*UDPConn)
|
||||
if !ok {
|
||||
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
|
||||
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
|
||||
conn.RemoveAddress(addr)
|
||||
return
|
||||
}
|
||||
|
||||
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
|
||||
if !ok {
|
||||
return
|
||||
// Userspace mode: UDPConn wrapper around nbnet.PacketConn
|
||||
if wrapped, ok := m.params.UDPConn.(*UDPConn); ok {
|
||||
if conn, ok := wrapped.GetPacketConn().(*nbnet.PacketConn); ok {
|
||||
conn.RemoveAddress(addr)
|
||||
}
|
||||
}
|
||||
|
||||
nbnetConn.RemoveAddress(addr)
|
||||
}
|
||||
|
||||
@@ -3,4 +3,4 @@
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "nb0"
|
||||
const WgInterfaceDefault = "wt0"
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
@@ -99,8 +100,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var udpConn net.PacketConn = rawSock
|
||||
if !nbnet.AdvancedRouting() {
|
||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
||||
}
|
||||
|
||||
bindParams := bind.UniversalUDPMuxParams{
|
||||
UDPConn: rawSock,
|
||||
UDPConn: udpConn,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
)
|
||||
|
||||
var defaultInterfaceBlacklist = []string{
|
||||
iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
|
||||
@@ -1393,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("nb%d", i)
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
|
||||
@@ -33,6 +33,15 @@ func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAd
|
||||
|
||||
}
|
||||
|
||||
// Add this method to the Manager struct
|
||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
listener, exists := m.peers[peerConnID]
|
||||
return listener, exists
|
||||
}
|
||||
|
||||
func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
mocWgInterface := &MocWGIface{}
|
||||
|
||||
@@ -51,7 +60,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
listener, exists := mgr.GetPeerListener(peerCfg1.PeerConnID)
|
||||
if !exists {
|
||||
t.Fatalf("peer listener not found")
|
||||
}
|
||||
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
@@ -128,11 +142,21 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
listener, exists := mgr.GetPeerListener(peerCfg1.PeerConnID)
|
||||
if !exists {
|
||||
t.Fatalf("peer listener for peer1 not found")
|
||||
}
|
||||
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
listener, exists = mgr.GetPeerListener(peerCfg2.PeerConnID)
|
||||
if !exists {
|
||||
t.Fatalf("peer listener for peer2 not found")
|
||||
}
|
||||
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Name() string {
|
||||
return "nb0"
|
||||
return "wt0"
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||
|
||||
@@ -24,7 +24,7 @@ type WorkerRelay struct {
|
||||
isController bool
|
||||
config ConnConfig
|
||||
conn *Conn
|
||||
relayManager relayClient.ManagerService
|
||||
relayManager *relayClient.Manager
|
||||
|
||||
relayedConn net.Conn
|
||||
relayLock sync.Mutex
|
||||
@@ -34,7 +34,7 @@ type WorkerRelay struct {
|
||||
wgWatcher *WGWatcher
|
||||
}
|
||||
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
|
||||
r := &WorkerRelay{
|
||||
peerCtx: ctx,
|
||||
log: log,
|
||||
|
||||
@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "nb0",
|
||||
name: "wt0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
|
||||
@@ -1330,13 +1330,6 @@ func (x *PeerState) GetRelayAddress() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PeerState) GetConnectionType() string {
|
||||
if x.Relayed {
|
||||
return "Relayed"
|
||||
}
|
||||
return "P2P"
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
type LocalPeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
|
||||
@@ -203,13 +203,18 @@ func mapPeers(
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := ""
|
||||
connType := "P2P"
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) {
|
||||
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
|
||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
@@ -219,7 +224,6 @@ func mapPeers(
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||
connType = pbPeerState.GetConnectionType()
|
||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
@@ -540,7 +544,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
return peersString
|
||||
}
|
||||
|
||||
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool {
|
||||
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter, connType string) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := true
|
||||
@@ -569,7 +573,7 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) {
|
||||
if connectionTypeFilter != "" && !strings.EqualFold(connType, connectionTypeFilter) {
|
||||
connectionTypeEval = true
|
||||
}
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -63,7 +63,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250724151510-c007bc6b392c
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250724151510-c007bc6b392c h1:OtX903X0FKEE+fcsp/P2701md7X/xbi/W/ojWIJNKSk=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250724151510-c007bc6b392c/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
x-default: &default
|
||||
restart: 'unless-stopped'
|
||||
logging:
|
||||
driver: 'json-file'
|
||||
options:
|
||||
max-size: '500m'
|
||||
max-file: '2'
|
||||
|
||||
services:
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- 80:80
|
||||
- 443:443
|
||||
@@ -27,16 +35,11 @@ services:
|
||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
||||
volumes:
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
<<: *default
|
||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
ports:
|
||||
@@ -44,16 +47,11 @@ services:
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Relay
|
||||
relay:
|
||||
<<: *default
|
||||
image: netbirdio/relay:$NETBIRD_RELAY_TAG
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- NB_LOG_LEVEL=info
|
||||
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT
|
||||
@@ -62,16 +60,11 @@ services:
|
||||
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
ports:
|
||||
- $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Management
|
||||
management:
|
||||
<<: *default
|
||||
image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- dashboard
|
||||
volumes:
|
||||
@@ -90,19 +83,14 @@ services:
|
||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||
]
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
environment:
|
||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
||||
- NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN
|
||||
|
||||
# Coturn
|
||||
coturn:
|
||||
<<: *default
|
||||
image: coturn/coturn:$COTURN_TAG
|
||||
restart: unless-stopped
|
||||
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
@@ -111,11 +99,6 @@ services:
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
x-default: &default
|
||||
restart: 'unless-stopped'
|
||||
logging:
|
||||
driver: 'json-file'
|
||||
options:
|
||||
max-size: '500m'
|
||||
max-file: '2'
|
||||
|
||||
services:
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
# Endpoints
|
||||
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
@@ -28,16 +36,11 @@ services:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-dashboard.rule=Host(`$NETBIRD_DOMAIN`)
|
||||
- traefik.http.services.netbird-dashboard.loadbalancer.server.port=80
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
<<: *default
|
||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
labels:
|
||||
@@ -45,27 +48,17 @@ services:
|
||||
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
# Relay
|
||||
relay:
|
||||
<<: *default
|
||||
image: netbirdio/relay:$NETBIRD_RELAY_TAG
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- NB_LOG_LEVEL=info
|
||||
- NB_LISTEN_ADDRESS=:33080
|
||||
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
|
||||
# todo: change to a secure secret
|
||||
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-relay.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/relay`)
|
||||
@@ -73,8 +66,8 @@ services:
|
||||
|
||||
# Management
|
||||
management:
|
||||
<<: *default
|
||||
image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- dashboard
|
||||
volumes:
|
||||
@@ -99,30 +92,20 @@ services:
|
||||
- traefik.http.routers.netbird-management.service=netbird-management
|
||||
- traefik.http.services.netbird-management.loadbalancer.server.port=33073
|
||||
- traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
environment:
|
||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
||||
- NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN
|
||||
|
||||
# Coturn
|
||||
coturn:
|
||||
<<: *default
|
||||
image: coturn/coturn:$COTURN_TAG
|
||||
restart: unless-stopped
|
||||
domainname: $TURN_DOMAIN
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "500m"
|
||||
max-file: "2"
|
||||
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
|
||||
@@ -780,7 +780,6 @@ EOF
|
||||
|
||||
renderDockerCompose() {
|
||||
cat <<EOF
|
||||
version: "3.4"
|
||||
services:
|
||||
# Caddy reverse proxy
|
||||
caddy:
|
||||
|
||||
@@ -101,7 +101,7 @@ type Manager interface {
|
||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManager() idp.Manager
|
||||
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -12,34 +13,44 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account.
|
||||
// UpdateIntegratedValidator updates the integrated validator groups for a specified account.
|
||||
// It retrieves the account associated with the provided userID, then updates the integrated validator groups
|
||||
// with the provided list of group ids. The updated account is then saved.
|
||||
//
|
||||
// Parameters:
|
||||
// - accountID: The ID of the account for which integrated validator groups are to be updated.
|
||||
// - userID: The ID of the user whose account is being updated.
|
||||
// - validator: The validator type to use, or empty to remove.
|
||||
// - groups: A slice of strings representing the ids of integrated validator groups to be updated.
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if any occurred during the process, otherwise returns nil
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
|
||||
ok, err := am.GroupValidation(ctx, accountID, groups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
|
||||
return err
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error {
|
||||
if validator != "" && len(groups) == 0 {
|
||||
return fmt.Errorf("at least one group must be specified for validator")
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("invalid groups")
|
||||
return errors.New("invalid groups")
|
||||
if validator != "" {
|
||||
ok, err := am.GroupValidation(ctx, accountID, groups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("invalid groups")
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
} else {
|
||||
// ensure groups is empty
|
||||
groups = []string{}
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
a, err := transaction.GetAccountByUser(ctx, userID)
|
||||
a, err := transaction.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -52,6 +63,8 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
extra = &types.ExtraSettings{}
|
||||
a.Settings.Extra = extra
|
||||
}
|
||||
|
||||
extra.IntegratedValidator = validator
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return transaction.SaveAccount(ctx, a)
|
||||
})
|
||||
@@ -99,7 +112,7 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
|
||||
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
||||
}
|
||||
|
||||
type MockIntegratedValidator struct {
|
||||
@@ -118,7 +131,7 @@ func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
func (a MockIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
|
||||
func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for _, peer := range peers {
|
||||
validatedPeers[peer.ID] = struct{}{}
|
||||
@@ -134,7 +147,7 @@ func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID strin
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||
func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string, extraSettings *types.ExtraSettings) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ type IntegratedValidator interface {
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string) error
|
||||
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
||||
SetPeerInvalidationListener(fn func(accountID string))
|
||||
Stop(ctx context.Context)
|
||||
ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow
|
||||
|
||||
@@ -102,7 +102,7 @@ type MockAccountManager struct {
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
@@ -769,10 +769,10 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
|
||||
if am.UpdateIntegratedValidatorGroupsFunc != nil {
|
||||
return am.UpdateIntegratedValidatorGroupsFunc(ctx, accountID, userID, groups)
|
||||
// UpdateIntegratedValidator mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error {
|
||||
if am.UpdateIntegratedValidatorFunc != nil {
|
||||
return am.UpdateIntegratedValidatorFunc(ctx, accountID, userID, validator, groups)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented")
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -87,7 +88,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -412,7 +413,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
groups[groupID] = group.Peers
|
||||
}
|
||||
|
||||
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1036,7 +1037,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -1156,7 +1157,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1183,6 +1184,8 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
// UpdateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
@@ -1204,7 +1207,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err)
|
||||
return
|
||||
@@ -1288,6 +1291,8 @@ type bufferUpdate struct {
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
@@ -1337,7 +1342,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
|
||||
return
|
||||
@@ -1571,7 +1576,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
}
|
||||
}
|
||||
|
||||
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil {
|
||||
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -11,14 +11,17 @@ import (
|
||||
// Scheduler is an interface which implementations can schedule and cancel jobs
|
||||
type Scheduler interface {
|
||||
Cancel(ctx context.Context, IDs []string)
|
||||
CancelAll(ctx context.Context)
|
||||
Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||
IsSchedulerRunning(ID string) bool
|
||||
}
|
||||
|
||||
// MockScheduler is a mock implementation of Scheduler
|
||||
type MockScheduler struct {
|
||||
CancelFunc func(ctx context.Context, IDs []string)
|
||||
ScheduleFunc func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||
CancelFunc func(ctx context.Context, IDs []string)
|
||||
CancelAllFunc func(ctx context.Context)
|
||||
ScheduleFunc func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||
IsSchedulerRunningFunc func(ID string) bool
|
||||
}
|
||||
|
||||
// Cancel mocks the Cancel function of the Scheduler interface
|
||||
@@ -30,6 +33,15 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) {
|
||||
log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ")
|
||||
}
|
||||
|
||||
// CancelAll mocks the CancelAll function of the Scheduler interface
|
||||
func (mock *MockScheduler) CancelAll(ctx context.Context) {
|
||||
if mock.CancelAllFunc != nil {
|
||||
mock.CancelAllFunc(ctx)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Warnf("MockScheduler doesn't have CancelAll function defined ")
|
||||
}
|
||||
|
||||
// Schedule mocks the Schedule function of the Scheduler interface
|
||||
func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
if mock.ScheduleFunc != nil {
|
||||
@@ -40,7 +52,9 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st
|
||||
}
|
||||
|
||||
func (mock *MockScheduler) IsSchedulerRunning(ID string) bool {
|
||||
// MockScheduler does not implement IsSchedulerRunning, so we return false
|
||||
if mock.IsSchedulerRunningFunc != nil {
|
||||
return mock.IsSchedulerRunningFunc(ID)
|
||||
}
|
||||
log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined")
|
||||
return false
|
||||
}
|
||||
@@ -52,6 +66,15 @@ type DefaultScheduler struct {
|
||||
mu *sync.Mutex
|
||||
}
|
||||
|
||||
func (wm *DefaultScheduler) CancelAll(ctx context.Context) {
|
||||
wm.mu.Lock()
|
||||
defer wm.mu.Unlock()
|
||||
|
||||
for id := range wm.jobs {
|
||||
wm.cancel(ctx, id)
|
||||
}
|
||||
}
|
||||
|
||||
// NewDefaultScheduler creates an instance of a DefaultScheduler
|
||||
func NewDefaultScheduler() *DefaultScheduler {
|
||||
return &DefaultScheduler{
|
||||
|
||||
@@ -75,6 +75,38 @@ func TestScheduler_Cancel(t *testing.T) {
|
||||
assert.NotNil(t, scheduler.jobs[jobID2])
|
||||
}
|
||||
|
||||
func TestScheduler_CancelAll(t *testing.T) {
|
||||
jobID1 := "test-scheduler-job-1"
|
||||
jobID2 := "test-scheduler-job-2"
|
||||
scheduler := NewDefaultScheduler()
|
||||
tChan := make(chan struct{})
|
||||
p := []string{jobID1, jobID2}
|
||||
scheduletime := 2 * time.Millisecond
|
||||
sleepTime := 4 * time.Millisecond
|
||||
if runtime.GOOS == "windows" {
|
||||
// sleep and ticker are slower on windows see https://github.com/golang/go/issues/44343
|
||||
sleepTime = 20 * time.Millisecond
|
||||
}
|
||||
|
||||
scheduler.Schedule(context.Background(), scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) {
|
||||
tt := p[0]
|
||||
<-tChan
|
||||
t.Logf("job %s", tt)
|
||||
return scheduletime, true
|
||||
})
|
||||
scheduler.Schedule(context.Background(), scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
|
||||
return scheduletime, true
|
||||
})
|
||||
|
||||
time.Sleep(sleepTime)
|
||||
assert.Len(t, scheduler.jobs, 2)
|
||||
scheduler.CancelAll(context.Background())
|
||||
close(tChan)
|
||||
p = []string{}
|
||||
time.Sleep(sleepTime)
|
||||
assert.Len(t, scheduler.jobs, 0)
|
||||
}
|
||||
|
||||
func TestScheduler_Schedule(t *testing.T) {
|
||||
jobID := "test-scheduler-job-1"
|
||||
scheduler := NewDefaultScheduler()
|
||||
|
||||
@@ -77,6 +77,8 @@ type ExtraSettings struct {
|
||||
// PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator
|
||||
PeerApprovalEnabled bool
|
||||
|
||||
// IntegratedValidator is the string enum for the integrated validator type
|
||||
IntegratedValidator string
|
||||
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
|
||||
IntegratedValidatorGroups []string `gorm:"serializer:json"`
|
||||
|
||||
@@ -93,5 +95,10 @@ func (e *ExtraSettings) Copy() *ExtraSettings {
|
||||
return &ExtraSettings{
|
||||
PeerApprovalEnabled: e.PeerApprovalEnabled,
|
||||
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...),
|
||||
IntegratedValidator: e.IntegratedValidator,
|
||||
FlowEnabled: e.FlowEnabled,
|
||||
FlowPacketCounterEnabled: e.FlowPacketCounterEnabled,
|
||||
FlowENCollectionEnabled: e.FlowENCollectionEnabled,
|
||||
FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,8 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
log.WithContext(ctx).Debugf("sending update to peer %s, checks: %s", peerID, update.Update.Checks)
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
|
||||
@@ -292,7 +292,7 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||
err = relayClient.Connect(ctx)
|
||||
if err != nil {
|
||||
if err = relayClient.Connect(ctx); err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := relayClient.Close(); err != nil {
|
||||
log.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func(_ string) {
|
||||
@@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
|
||||
select {
|
||||
case <-disconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Fatalf("timeout waiting for client to disconnect")
|
||||
log.Errorf("timeout waiting for client to disconnect")
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn(ctx, "bob")
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
const (
|
||||
DefaultConnectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
@@ -25,16 +25,18 @@ type dialResult struct {
|
||||
}
|
||||
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
dialerFns []DialeFn
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
return &RaceDial{
|
||||
log: log,
|
||||
serverURL: serverURL,
|
||||
dialerFns: dialerFns,
|
||||
log: log,
|
||||
serverURL: serverURL,
|
||||
dialerFns: dialerFns,
|
||||
connectionTimeout: connectionTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
|
||||
@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
rd := NewRaceDial(logger, serverURL)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error with empty dialers, got nil")
|
||||
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
||||
protocolStr: proto,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
||||
protocolStr: "proto2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
||||
if conn.RemoteAddr().Network() != proto2 {
|
||||
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
||||
}
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
func TestRaceDialTimeout(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
connectionTimeout = 3 * time.Second
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
|
||||
protocolStr: "protocol2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
protocolStr: proto2,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
const (
|
||||
// TODO: make it configurable, the manager should validate all configurable parameters
|
||||
reconnectingTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
|
||||
@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
|
||||
|
||||
type OnServerCloseListener func()
|
||||
|
||||
// ManagerService is the interface for the relay manager.
|
||||
type ManagerService interface {
|
||||
Serve() error
|
||||
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
|
||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||
RelayInstanceAddress() (string, error)
|
||||
ServerURLs() []string
|
||||
HasRelayAddress() bool
|
||||
UpdateToken(token *relayAuth.Token) error
|
||||
}
|
||||
|
||||
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
|
||||
// and automatically reconnect to them in case disconnection.
|
||||
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
)
|
||||
|
||||
func TestEmptyURL(t *testing.T) {
|
||||
mgr := NewManager(context.Background(), nil, "alice")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
mgr := NewManager(ctx, nil, "alice")
|
||||
err := mgr.Serve()
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
@@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
func TestForeignAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
keepUnusedServerTime = 2 * time.Second
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
@@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
// Set up a disconnect listener to track when foreign server disconnects
|
||||
foreignServerURL := toURL(srvCfg2)[0]
|
||||
disconnected := make(chan struct{})
|
||||
onDisconnect := func() {
|
||||
select {
|
||||
case disconnected <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
|
||||
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
|
||||
t.Fatalf("should have failed to open connection to another peer")
|
||||
}
|
||||
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
||||
// Add the disconnect listener after the connection attempt
|
||||
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
|
||||
t.Logf("failed to add close listener (expected if connection failed): %s", err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to happen
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
|
||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||
time.Sleep(timeout)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
t.Log("foreign relay connection cleaned up successfully")
|
||||
case <-time.After(timeout):
|
||||
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
@@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
|
||||
srvCfg := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
@@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
if err := srv.Listen(srvCfg); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -4,38 +4,76 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Mutex to protect global variable access in tests
|
||||
var testMutex sync.Mutex
|
||||
|
||||
func TestNewReceiver(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 5 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
t.Error("unexpected timeout")
|
||||
case <-time.After(1 * time.Second):
|
||||
|
||||
// Test passes if no timeout received
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReceiverNotReceive(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 1 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
// Test passes if timeout is received
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("timeout not received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReceiverAck(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 2 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
r.Heartbeat()
|
||||
|
||||
@@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := heartbeatTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
healthCheckInterval = originalInterval
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
|
||||
@@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
sender := NewSender(log.WithField("test_name", tc.name))
|
||||
go sender.StartHealthCheck(ctx)
|
||||
senderExit := make(chan struct{})
|
||||
go func() {
|
||||
sender.StartHealthCheck(ctx)
|
||||
close(senderExit)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
responded := false
|
||||
@@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
t.Fatalf("should have timed out before %s", testTimeout)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-senderExit:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("sender did not exit in time")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -20,12 +20,12 @@ type Metrics struct {
|
||||
TransferBytesRecv metric.Int64Counter
|
||||
AuthenticationTime metric.Float64Histogram
|
||||
PeerStoreTime metric.Float64Histogram
|
||||
|
||||
peers metric.Int64UpDownCounter
|
||||
peerActivityChan chan string
|
||||
peerLastActive map[string]time.Time
|
||||
mutexActivity sync.Mutex
|
||||
ctx context.Context
|
||||
peerReconnections metric.Int64Counter
|
||||
peers metric.Int64UpDownCounter
|
||||
peerActivityChan chan string
|
||||
peerLastActive map[string]time.Time
|
||||
mutexActivity sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
@@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
|
||||
metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Metrics{
|
||||
Meter: meter,
|
||||
TransferBytesSent: bytesSent,
|
||||
@@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
AuthenticationTime: authTime,
|
||||
PeerStoreTime: peerStoreTime,
|
||||
peers: peers,
|
||||
peerReconnections: peerReconnections,
|
||||
|
||||
ctx: ctx,
|
||||
peerActivityChan: make(chan string, 10),
|
||||
@@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
|
||||
delete(m.peerLastActive, id)
|
||||
}
|
||||
|
||||
func (m *Metrics) RecordPeerReconnection() {
|
||||
m.peerReconnections.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// PeerActivity increases the active connections
|
||||
func (m *Metrics) PeerActivity(peerID string) {
|
||||
select {
|
||||
|
||||
@@ -18,12 +18,9 @@ type Listener struct {
|
||||
TLSConfig *tls.Config
|
||||
|
||||
listener *quic.Listener
|
||||
acceptFn func(conn net.Conn)
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
l.acceptFn = acceptFn
|
||||
|
||||
quicCfg := &quic.Config{
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: 1452,
|
||||
@@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
|
||||
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
|
||||
conn := NewConn(session)
|
||||
l.acceptFn(conn)
|
||||
acceptFn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,9 @@ type Peer struct {
|
||||
notifier *store.PeerNotifier
|
||||
|
||||
peersListener *store.Listener
|
||||
|
||||
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
|
||||
notificationMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewPeer creates a new Peer instance and prepare custom logging
|
||||
@@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
|
||||
}
|
||||
|
||||
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
|
||||
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
|
||||
|
||||
// collect online peers to response back to the caller
|
||||
p.notificationMutex.Lock()
|
||||
defer p.notificationMutex.Unlock()
|
||||
|
||||
onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
|
||||
if len(onlinePeers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
p.log.Debugf("response with %d online peers", len(onlinePeers))
|
||||
p.sendPeersOnline(onlinePeers)
|
||||
}
|
||||
@@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
|
||||
}
|
||||
|
||||
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
||||
p.notificationMutex.Lock()
|
||||
defer p.notificationMutex.Unlock()
|
||||
|
||||
msgs, err := messages.MarshalPeersWentOffline(peers)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||
|
||||
@@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) {
|
||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||
}
|
||||
|
||||
peerStore := store.NewStore()
|
||||
r := &Relay{
|
||||
metrics: m,
|
||||
metricsCancel: metricsCancel,
|
||||
validator: config.AuthValidator,
|
||||
instanceURL: config.instanceURL,
|
||||
store: peerStore,
|
||||
notifier: store.NewPeerNotifier(peerStore),
|
||||
store: store.NewStore(),
|
||||
notifier: store.NewPeerNotifier(),
|
||||
}
|
||||
|
||||
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||
@@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
storeTime := time.Now()
|
||||
r.store.AddPeer(peer)
|
||||
if isReconnection := r.store.AddPeer(peer); isReconnection {
|
||||
r.metrics.RecordPeerReconnection()
|
||||
}
|
||||
r.notifier.PeerCameOnline(peer.ID())
|
||||
|
||||
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||
r.metrics.PeerConnected(peer.String())
|
||||
go func() {
|
||||
peer.Work()
|
||||
r.notifier.PeerWentOffline(peer.ID())
|
||||
r.store.DeletePeer(peer)
|
||||
if deleted := r.store.DeletePeer(peer); deleted {
|
||||
r.notifier.PeerWentOffline(peer.ID())
|
||||
}
|
||||
peer.log.Debugf("relay connection closed")
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
}()
|
||||
|
||||
@@ -7,24 +7,27 @@ import (
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
ctx context.Context
|
||||
store *Store
|
||||
type event struct {
|
||||
peerID messages.PeerID
|
||||
online bool
|
||||
}
|
||||
|
||||
onlineChan chan messages.PeerID
|
||||
offlineChan chan messages.PeerID
|
||||
type Listener struct {
|
||||
ctx context.Context
|
||||
|
||||
eventChan chan *event
|
||||
interestedPeersForOffline map[messages.PeerID]struct{}
|
||||
interestedPeersForOnline map[messages.PeerID]struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newListener(ctx context.Context, store *Store) *Listener {
|
||||
func newListener(ctx context.Context) *Listener {
|
||||
l := &Listener{
|
||||
ctx: ctx,
|
||||
store: store,
|
||||
ctx: ctx,
|
||||
|
||||
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
||||
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
||||
// important to use a single channel for offline and online events because with it we can ensure all events
|
||||
// will be processed in the order they were sent
|
||||
eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
|
||||
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
|
||||
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
|
||||
}
|
||||
@@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener {
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
|
||||
availablePeers := make([]messages.PeerID, 0)
|
||||
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
@@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer
|
||||
l.interestedPeersForOnline[id] = struct{}{}
|
||||
l.interestedPeersForOffline[id] = struct{}{}
|
||||
}
|
||||
|
||||
// collect online peers to response back to the caller
|
||||
for _, id := range peerIDs {
|
||||
_, ok := l.store.Peer(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
availablePeers = append(availablePeers, id)
|
||||
}
|
||||
return availablePeers
|
||||
}
|
||||
|
||||
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
|
||||
@@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
|
||||
for _, id := range peerIDs {
|
||||
delete(l.interestedPeersForOffline, id)
|
||||
delete(l.interestedPeersForOnline, id)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
|
||||
select {
|
||||
case <-l.ctx.Done():
|
||||
return
|
||||
case pID := <-l.onlineChan:
|
||||
peers := make([]messages.PeerID, 0)
|
||||
peers = append(peers, pID)
|
||||
|
||||
for len(l.onlineChan) > 0 {
|
||||
pID = <-l.onlineChan
|
||||
peers = append(peers, pID)
|
||||
case e := <-l.eventChan:
|
||||
peersOffline := make([]messages.PeerID, 0)
|
||||
peersOnline := make([]messages.PeerID, 0)
|
||||
if e.online {
|
||||
peersOnline = append(peersOnline, e.peerID)
|
||||
} else {
|
||||
peersOffline = append(peersOffline, e.peerID)
|
||||
}
|
||||
|
||||
onPeersComeOnline(peers)
|
||||
case pID := <-l.offlineChan:
|
||||
peers := make([]messages.PeerID, 0)
|
||||
peers = append(peers, pID)
|
||||
|
||||
for len(l.offlineChan) > 0 {
|
||||
pID = <-l.offlineChan
|
||||
peers = append(peers, pID)
|
||||
// Drain the channel to collect all events
|
||||
for len(l.eventChan) > 0 {
|
||||
e = <-l.eventChan
|
||||
if e.online {
|
||||
peersOnline = append(peersOnline, e.peerID)
|
||||
} else {
|
||||
peersOffline = append(peersOffline, e.peerID)
|
||||
}
|
||||
}
|
||||
|
||||
onPeersWentOffline(peers)
|
||||
if len(peersOnline) > 0 {
|
||||
onPeersComeOnline(peersOnline)
|
||||
}
|
||||
if len(peersOffline) > 0 {
|
||||
onPeersWentOffline(peersOffline)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
|
||||
|
||||
if _, ok := l.interestedPeersForOffline[peerID]; ok {
|
||||
select {
|
||||
case l.offlineChan <- peerID:
|
||||
case l.eventChan <- &event{
|
||||
peerID: peerID,
|
||||
online: false,
|
||||
}:
|
||||
case <-l.ctx.Done():
|
||||
}
|
||||
}
|
||||
@@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
|
||||
|
||||
if _, ok := l.interestedPeersForOnline[peerID]; ok {
|
||||
select {
|
||||
case l.onlineChan <- peerID:
|
||||
case l.eventChan <- &event{
|
||||
peerID: peerID,
|
||||
online: true,
|
||||
}:
|
||||
case <-l.ctx.Done():
|
||||
}
|
||||
|
||||
delete(l.interestedPeersForOnline, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,15 +8,12 @@ import (
|
||||
)
|
||||
|
||||
type PeerNotifier struct {
|
||||
store *Store
|
||||
|
||||
listeners map[*Listener]context.CancelFunc
|
||||
listenersMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewPeerNotifier(store *Store) *PeerNotifier {
|
||||
func NewPeerNotifier() *PeerNotifier {
|
||||
pn := &PeerNotifier{
|
||||
store: store,
|
||||
listeners: make(map[*Listener]context.CancelFunc),
|
||||
}
|
||||
return pn
|
||||
@@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
|
||||
|
||||
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
listener := newListener(ctx, pn.store)
|
||||
listener := newListener(ctx)
|
||||
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
|
||||
|
||||
pn.listenersMutex.Lock()
|
||||
|
||||
@@ -26,7 +26,9 @@ func NewStore() *Store {
|
||||
}
|
||||
|
||||
// AddPeer adds a peer to the store
|
||||
func (s *Store) AddPeer(peer IPeer) {
|
||||
// If the peer already exists, it will be replaced and the old peer will be closed
|
||||
// Returns true if the peer was replaced, false if it was added for the first time.
|
||||
func (s *Store) AddPeer(peer IPeer) bool {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
odlPeer, ok := s.peers[peer.ID()]
|
||||
@@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) {
|
||||
}
|
||||
|
||||
s.peers[peer.ID()] = peer
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeletePeer deletes a peer from the store
|
||||
func (s *Store) DeletePeer(peer IPeer) {
|
||||
func (s *Store) DeletePeer(peer IPeer) bool {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
|
||||
dp, ok := s.peers[peer.ID()]
|
||||
if !ok {
|
||||
return
|
||||
return false
|
||||
}
|
||||
if dp != peer {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
delete(s.peers, peer.ID())
|
||||
return true
|
||||
}
|
||||
|
||||
// Peer returns a peer by its ID
|
||||
@@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer {
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
|
||||
s.peersLock.RLock()
|
||||
defer s.peersLock.RUnlock()
|
||||
|
||||
onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
|
||||
|
||||
listener.AddInterestedPeers(peerIDs)
|
||||
|
||||
// Check for currently online peers
|
||||
for _, id := range peerIDs {
|
||||
if _, ok := s.peers[id]; ok {
|
||||
onlinePeers = append(onlinePeers, id)
|
||||
}
|
||||
}
|
||||
|
||||
return onlinePeers
|
||||
}
|
||||
|
||||
@@ -120,17 +120,8 @@ func (c *UDPConn) Close() error {
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
}
|
||||
|
||||
// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality
|
||||
func WrapUDPConn(conn *net.UDPConn) *UDPConn {
|
||||
return &UDPConn{
|
||||
UDPConn: conn,
|
||||
ID: GenerateConnID(),
|
||||
seenAddrs: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||
func (c *UDPConn) RemoveAddress(addr string) {
|
||||
func (c *PacketConn) RemoveAddress(addr string) {
|
||||
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||
return
|
||||
}
|
||||
@@ -159,6 +150,16 @@ func (c *UDPConn) RemoveAddress(addr string) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality
|
||||
func WrapPacketConn(conn net.PacketConn) *PacketConn {
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
ID: GenerateConnID(),
|
||||
seenAddrs: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
||||
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking
|
||||
func WrapUDPConn(conn *net.UDPConn) *net.UDPConn {
|
||||
// WrapPacketConn on iOS just returns the original connection since iOS handles its own networking
|
||||
func WrapPacketConn(conn *net.UDPConn) *net.UDPConn {
|
||||
return conn
|
||||
}
|
||||
|
||||
15
util/runtime.go
Normal file
15
util/runtime.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package util
|
||||
|
||||
import "runtime"
|
||||
|
||||
func GetCallerName() string {
|
||||
pc, _, _, ok := runtime.Caller(2)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
fn := runtime.FuncForPC(pc)
|
||||
if fn == nil {
|
||||
return "unknown"
|
||||
}
|
||||
return fn.Name()
|
||||
}
|
||||
Reference in New Issue
Block a user