Compare commits

..

1 Commits

Author SHA1 Message Date
Viktor Liu
b0c3818e06 Add set subcommand to set config flags 2025-07-02 14:28:33 +02:00
166 changed files with 3078 additions and 6373 deletions

View File

@@ -43,7 +43,7 @@ jobs:
- name: gomobile init
run: gomobile init
- name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -14,9 +14,6 @@
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
@@ -32,13 +29,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
<br/>
</strong>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
</a>
</p>
@@ -50,9 +47,10 @@
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open Source Network Security in a Single Platform
### Open-Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -64,9 +64,7 @@ type Client struct {
}
// NewClient instantiate a new Client
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: cfgFile,
@@ -205,10 +203,8 @@ func (c *Client) Networks() *NetworkArray {
continue
}
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
if routes[0].IsDynamic() {
continue
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
@@ -218,7 +214,7 @@ func (c *Client) Networks() *NetworkArray {
}
network := Network{
Name: string(id),
Network: netStr,
Network: routes[0].Network.String(),
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}

View File

@@ -1,26 +0,0 @@
//go:build android
package android
import (
"fmt"
_ "unsafe"
)
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
// In Android version 11 and earlier, pidfd-related system calls
// are not allowed by the seccomp policy, which causes crashes due
// to SIGSYS signals.
//go:linkname checkPidfdOnce os.checkPidfdOnce
var checkPidfdOnce func() error
func execWorkaround(androidSDKVersion int) {
if androidSDKVersion > 30 { // above Android 11
return
}
checkPidfdOnce = func() error {
return fmt.Errorf("unsupported Android version")
}
}

View File

@@ -17,18 +17,10 @@ import (
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const errCloseConnection = "Failed to close connection: %v"
var (
logFileCount uint32
systemInfoFlag bool
uploadBundleFlag bool
uploadBundleURLFlag string
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
@@ -96,13 +88,12 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -114,7 +105,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if uploadBundleFlag {
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -232,13 +223,12 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -265,7 +255,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if uploadBundleFlag {
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -385,15 +375,3 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}
func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}

210
client/cmd/flags.go Normal file
View File

@@ -0,0 +1,210 @@
package cmd
import (
"fmt"
"time"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/management/domain"
"github.com/spf13/cobra"
)
// SharedFlags contains all configuration flags that are common between up and set commands
type SharedFlags struct {
// Network configuration
InterfaceName string
WireguardPort uint16
NATExternalIPs []string
CustomDNSAddress string
ExtraIFaceBlackList []string
DNSLabels []string
DNSRouteInterval time.Duration
// Feature flags
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
AutoConnectDisabled bool
NetworkMonitor bool
LazyConnEnabled bool
// System flags
DisableClientRoutes bool
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
// Login-specific (only for up command)
NoBrowser bool
}
// AddSharedFlags adds all shared configuration flags to the given command
func AddSharedFlags(cmd *cobra.Command, flags *SharedFlags) {
// Network configuration flags
cmd.PersistentFlags().StringVar(&flags.InterfaceName, interfaceNameFlag, iface.WgInterfaceDefault,
"Wireguard interface name")
cmd.PersistentFlags().Uint16Var(&flags.WireguardPort, wireguardPortFlag, iface.DefaultWgPort,
"Wireguard interface listening port")
cmd.PersistentFlags().StringSliceVar(&flags.NATExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces. `+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
`An empty string "" clears the previous configuration. `+
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
`or --external-ip-map ""`)
cmd.PersistentFlags().StringVar(&flags.CustomDNSAddress, dnsResolverAddress, "",
`Sets a custom address for NetBird's local DNS resolver. `+
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
`An empty string "" clears the previous configuration. `+
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`)
cmd.PersistentFlags().StringSliceVar(&flags.ExtraIFaceBlackList, extraIFaceBlackListFlag, nil,
"Extra list of default interfaces to ignore for listening")
cmd.PersistentFlags().StringSliceVar(&flags.DNSLabels, dnsLabelsFlag, nil,
`Sets DNS labels. `+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`)
cmd.PersistentFlags().DurationVar(&flags.DNSRouteInterval, dnsRouteIntervalFlag, time.Minute,
"DNS route update interval")
// Feature flags
cmd.PersistentFlags().BoolVar(&flags.RosenpassEnabled, enableRosenpassFlag, false,
"[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
cmd.PersistentFlags().BoolVar(&flags.RosenpassPermissive, rosenpassPermissiveFlag, false,
"[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
cmd.PersistentFlags().BoolVar(&flags.ServerSSHAllowed, serverSSHAllowedFlag, false,
"Allow SSH server on peer. If enabled, the SSH server will be permitted")
cmd.PersistentFlags().BoolVar(&flags.AutoConnectDisabled, disableAutoConnectFlag, false,
"Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
cmd.PersistentFlags().BoolVarP(&flags.NetworkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`)
cmd.PersistentFlags().BoolVar(&flags.LazyConnEnabled, enableLazyConnectionFlag, false,
"[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
// System flags (from system.go)
cmd.PersistentFlags().BoolVar(&flags.DisableClientRoutes, disableClientRoutesFlag, false,
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
cmd.PersistentFlags().BoolVar(&flags.DisableServerRoutes, disableServerRoutesFlag, false,
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
cmd.PersistentFlags().BoolVar(&flags.DisableDNS, disableDNSFlag, false,
"Disable DNS. If enabled, the client won't configure DNS settings.")
cmd.PersistentFlags().BoolVar(&flags.DisableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
cmd.PersistentFlags().BoolVar(&flags.BlockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
cmd.PersistentFlags().BoolVar(&flags.BlockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
}
// AddUpOnlyFlags adds flags that are specific to the up command
func AddUpOnlyFlags(cmd *cobra.Command, flags *SharedFlags) {
cmd.PersistentFlags().BoolVar(&flags.NoBrowser, noBrowserFlag, false, noBrowserDesc)
}
// BuildConfigInput creates an internal.ConfigInput from SharedFlags with Changed() checks
func BuildConfigInput(cmd *cobra.Command, flags *SharedFlags, customDNSAddressConverted []byte) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
CustomDNSAddress: customDNSAddressConverted,
}
// Handle PreSharedKey from root command
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &flags.RosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &flags.RosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &flags.ServerSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(flags.InterfaceName); err != nil {
return nil, err
}
ic.InterfaceName = &flags.InterfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(flags.WireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &flags.NetworkMonitor
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &flags.AutoConnectDisabled
if flags.AutoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
} else {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &flags.DNSRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &flags.DisableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &flags.DisableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &flags.DisableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &flags.DisableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &flags.BlockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &flags.BlockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &flags.LazyConnEnabled
}
if cmd.Flag(externalIPMapFlag).Changed {
ic.NATExternalIPs = flags.NATExternalIPs
}
if cmd.Flag(extraIFaceBlackListFlag).Changed {
ic.ExtraIFaceBlackList = flags.ExtraIFaceBlackList
}
if cmd.Flag(dnsLabelsFlag).Changed {
var err error
ic.DNSLabels, err = domain.FromStringList(flags.DNSLabels)
if err != nil {
return nil, fmt.Errorf("invalid DNS labels: %v", err)
}
}
return &ic, nil
}

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
@@ -37,7 +38,10 @@ const (
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
)
var (
@@ -71,7 +75,10 @@ var (
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
debugUploadBundle bool
debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{
@@ -142,6 +149,7 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(setCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
@@ -160,25 +168,10 @@ func init() {
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
`An empty string "" clears the previous configuration. `+
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
`or --external-ip-map ""`,
)
upCmd.PersistentFlags().StringVar(&customDNSAddress, dnsResolverAddress, "",
`Sets a custom address for NetBird's local DNS resolver. `+
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
`An empty string "" clears the previous configuration. `+
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success

161
client/cmd/set.go Normal file
View File

@@ -0,0 +1,161 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/proto"
)
var (
setFlags = &SharedFlags{}
setCmd = &cobra.Command{
Use: "set",
Short: "Update NetBird client configuration",
Long: `Update NetBird client configuration without connecting. Uses the same flags as 'netbird up' but only updates the configuration file.`,
RunE: setFunc,
}
)
func init() {
// Add all shared flags to the set command
AddSharedFlags(setCmd, setFlags)
// Note: We don't add up-only flags like --no-browser to set command
}
func setFunc(cmd *cobra.Command, _ []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
// Validate inputs (reuse validation logic from up.go)
if err := validateNATExternalIPs(setFlags.NATExternalIPs); err != nil {
return err
}
if cmd.Flag(dnsLabelsFlag).Changed {
if _, err := validateDnsLabels(setFlags.DNSLabels); err != nil {
return err
}
}
var customDNSAddressConverted []byte
if cmd.Flag(dnsResolverAddress).Changed {
var err error
customDNSAddressConverted, err = parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
}
}
// Connect to daemon
ctx := cmd.Context()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if closeErr := conn.Close(); closeErr != nil {
fmt.Printf("Warning: failed to close connection: %v\n", closeErr)
}
}()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SetConfigRequest{}
// Set fields based on changed flags
if cmd.Flag(enableRosenpassFlag).Changed {
req.RosenpassEnabled = &setFlags.RosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
req.RosenpassPermissive = &setFlags.RosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &setFlags.ServerSSHAllowed
}
if cmd.Flag(disableAutoConnectFlag).Changed {
req.DisableAutoConnect = &setFlags.AutoConnectDisabled
}
if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &setFlags.NetworkMonitor
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(setFlags.InterfaceName); err != nil {
return err
}
req.InterfaceName = &setFlags.InterfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
port := int64(setFlags.WireguardPort)
req.WireguardPort = &port
}
if cmd.Flag(dnsResolverAddress).Changed {
customAddr := string(customDNSAddressConverted)
req.CustomDNSAddress = &customAddr
}
if cmd.Flag(extraIFaceBlackListFlag).Changed {
req.ExtraIFaceBlacklist = setFlags.ExtraIFaceBlackList
}
if cmd.Flag(dnsLabelsFlag).Changed {
req.DnsLabels = setFlags.DNSLabels
req.CleanDNSLabels = &[]bool{setFlags.DNSLabels != nil && len(setFlags.DNSLabels) == 0}[0]
}
if cmd.Flag(externalIPMapFlag).Changed {
req.NatExternalIPs = setFlags.NATExternalIPs
req.CleanNATExternalIPs = &[]bool{setFlags.NATExternalIPs != nil && len(setFlags.NATExternalIPs) == 0}[0]
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
req.DnsRouteInterval = durationpb.New(setFlags.DNSRouteInterval)
}
if cmd.Flag(disableClientRoutesFlag).Changed {
req.DisableClientRoutes = &setFlags.DisableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
req.DisableServerRoutes = &setFlags.DisableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
req.DisableDns = &setFlags.DisableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
req.DisableFirewall = &setFlags.DisableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
req.BlockLanAccess = &setFlags.BlockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
req.BlockInbound = &setFlags.BlockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &setFlags.LazyConnEnabled
}
// Send the request
if _, err := client.SetConfig(ctx, req); err != nil {
return fmt.Errorf("update configuration: %w", err)
}
cmd.Println("Configuration updated successfully")
return nil
}

110
client/cmd/set_test.go Normal file
View File

@@ -0,0 +1,110 @@
package cmd
import (
"testing"
"github.com/spf13/cobra"
)
func TestParseBoolArg(t *testing.T) {
tests := []struct {
input string
expected bool
hasError bool
}{
{"true", true, false},
{"True", true, false},
{"1", true, false},
{"yes", true, false},
{"on", true, false},
{"false", false, false},
{"False", false, false},
{"0", false, false},
{"no", false, false},
{"off", false, false},
{"invalid", false, true},
{"maybe", false, true},
{"", false, true},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
result, err := parseBoolArg(test.input)
if test.hasError {
if err == nil {
t.Errorf("Expected error for input %q, but got none", test.input)
}
} else {
if err != nil {
t.Errorf("Unexpected error for input %q: %v", test.input, err)
}
if result != test.expected {
t.Errorf("For input %q, expected %v but got %v", test.input, test.expected, result)
}
}
})
}
}
func TestSetCommandStructure(t *testing.T) {
// Test that the set command has the expected subcommands
expectedSubcommands := []string{
"autoconnect",
"ssh-server",
"network-monitor",
"rosenpass",
"dns",
"dns-interval",
}
actualSubcommands := make([]string, 0, len(setCmd.Commands()))
for _, cmd := range setCmd.Commands() {
actualSubcommands = append(actualSubcommands, cmd.Name())
}
if len(actualSubcommands) != len(expectedSubcommands) {
t.Errorf("Expected %d subcommands, got %d", len(expectedSubcommands), len(actualSubcommands))
}
for _, expected := range expectedSubcommands {
found := false
for _, actual := range actualSubcommands {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected subcommand %q not found", expected)
}
}
}
func TestSetCommandUsage(t *testing.T) {
if setCmd.Use != "set" {
t.Errorf("Expected command use to be 'set', got %q", setCmd.Use)
}
if setCmd.Short != "Set NetBird client configuration" {
t.Errorf("Expected short description to be 'Set NetBird client configuration', got %q", setCmd.Short)
}
}
func TestSubcommandArgRequirements(t *testing.T) {
// Test that all subcommands except dns-interval require exactly 1 argument
subcommands := []*cobra.Command{
setAutoconnectCmd,
setSSHServerCmd,
setNetworkMonitorCmd,
setRosenpassCmd,
setDNSCmd,
setDNSIntervalCmd,
}
for _, cmd := range subcommands {
if cmd.Args == nil {
t.Errorf("Command %q should have Args validation", cmd.Name())
}
}
}

View File

@@ -19,24 +19,3 @@ var (
blockInbound bool
)
func init() {
// Add system flags to upCmd
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
"Disable DNS. If enabled, the client won't configure DNS settings.")
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
}

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -15,7 +14,6 @@ import (
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -42,6 +40,7 @@ var (
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
upFlags = &SharedFlags{}
upCmd = &cobra.Command{
Use: "up",
@@ -51,26 +50,12 @@ var (
)
func init() {
// Add shared flags to up command
AddSharedFlags(upCmd, upFlags)
// Add up-specific flags
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
AddUpOnlyFlags(upCmd, upFlags)
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -118,7 +103,16 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}
ic, err := setupConfig(customDNSAddressConverted, cmd)
// Handle DNS labels validation and assignment to SharedFlags
if cmd.Flag(dnsLabelsFlag).Changed {
var err error
dnsLabelsValidated, err = validateDnsLabels(upFlags.DNSLabels)
if err != nil {
return err
}
}
ic, err := BuildConfigInput(cmd, upFlags, customDNSAddressConverted)
if err != nil {
return fmt.Errorf("setup config: %v", err)
}
@@ -235,92 +229,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
return &ic, nil
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,

View File

@@ -1,408 +0,0 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net/netip"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
}
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
if len(header)%2 == 1 {
sum += uint32(header[len(header)-1]) << 8
}
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
i += 16
}
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
reverse: make(map[netip.Addr]netip.Addr),
}
}
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
delete(b.reverse, translated)
}
}
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
}
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
}
m.dnatMappings[originalAddr] = translatedAddr
m.dnatBiMap.set(originalAddr, translatedAddr)
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
if _, exists := m.dnatMappings[originalAddr]; !exists {
return fmt.Errorf("mapping not found for: %s", originalAddr)
}
delete(m.dnatMappings, originalAddr)
m.dnatBiMap.delete(originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
}
m.dnatMutex.RLock()
translated, exists := m.dnatBiMap.getTranslated(addr)
m.dnatMutex.RUnlock()
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
}
m.dnatMutex.RLock()
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
m.dnatMutex.RUnlock()
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
return
}
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return
}
checksumOffset := udpStart + 6
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum == 0 {
return
}
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
return
}
icmpData := packetData[icmpStart:]
binary.BigEndian.PutUint16(icmpData[2:4], 0)
checksum := icmpChecksum(icmpData)
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}

View File

@@ -1,416 +0,0 @@
package uspfilter
import (
"fmt"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
)
// BenchmarkDNATTranslation measures the performance of DNAT operations
func BenchmarkDNATTranslation(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
description string
}{
{
name: "tcp_with_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: true,
description: "TCP packet with DNAT translation enabled",
},
{
name: "tcp_without_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
description: "TCP packet without DNAT (baseline)",
},
{
name: "udp_with_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: true,
description: "UDP packet with DNAT translation enabled",
},
{
name: "udp_without_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
description: "UDP packet without DNAT (baseline)",
},
{
name: "icmp_with_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: true,
description: "ICMP packet with DNAT translation enabled",
},
{
name: "icmp_without_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: false,
description: "ICMP packet without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mapping if needed
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
if sc.setupDNAT {
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Create test packets
srcIP := netip.MustParseAddr("172.16.0.1")
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
// Pre-establish connection for reverse DNAT test
if sc.setupDNAT {
manager.filterOutbound(outboundPacket, 0)
}
b.ResetTimer()
// Benchmark outbound DNAT translation
b.Run("outbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
// Benchmark inbound reverse DNAT translation
if sc.setupDNAT {
b.Run("inbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
manager.filterInbound(packet, 0)
}
})
}
})
}
}
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup multiple DNAT mappings
numMappings := 100
originalIPs := make([]netip.Addr, numMappings)
translatedIPs := make([]netip.Addr, numMappings)
for i := 0; i < numMappings; i++ {
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
require.NoError(b, err)
}
srcIP := netip.MustParseAddr("172.16.0.1")
// Pre-generate packets
outboundPackets := make([][]byte, numMappings)
inboundPackets := make([][]byte, numMappings)
for i := 0; i < numMappings; i++ {
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
// Establish connections
manager.filterOutbound(outboundPackets[i], 0)
}
b.ResetTimer()
b.Run("concurrent_outbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
i++
}
})
})
b.Run("concurrent_inbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
manager.filterInbound(packet, 0)
i++
}
})
})
}
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
func BenchmarkDNATScaling(b *testing.B) {
mappingCounts := []int{1, 10, 100, 1000}
for _, count := range mappingCounts {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mappings
for i := 0; i < count; i++ {
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Test with the last mapping added (worst case for lookup)
srcIP := netip.MustParseAddr("172.16.0.1")
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
}
}
// generateDNATTestPacket creates a test packet for DNAT benchmarking
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
tb.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: proto,
}
var transportLayer gopacket.SerializableLayer
switch proto {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
case layers.IPProtocolICMPv4:
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
}
transportLayer = icmp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(tb, err)
return buf.Bytes()
}
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
func BenchmarkChecksumUpdate(b *testing.B) {
// Create test data for checksum calculations
testData := make([]byte, 64) // Typical packet size for checksum testing
for i := range testData {
testData[i] = byte(i)
}
b.Run("ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
}
})
b.Run("icmp_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = icmpChecksum(testData)
}
})
b.Run("incremental_update", func(b *testing.B) {
oldBytes := []byte{192, 168, 1, 100}
newBytes := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
}
})
}
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Create fresh packet each time to isolate allocation testing
testPacket := make([]byte, len(packet))
copy(testPacket, packet)
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
}
}
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
func BenchmarkDirectIPExtraction(b *testing.B) {
// Create a test packet
srcIP := netip.MustParseAddr("172.16.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
b.Run("direct_byte_access", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Direct extraction from packet bytes
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
}
})
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
// Extract using decoder (traditional method)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
_ = dst
}
})
}
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
func BenchmarkChecksumOptimizations(b *testing.B) {
// Create test IPv4 header (20 bytes)
header := make([]byte, 20)
for i := range header {
header[i] = byte(i)
}
// Clear checksum field
header[10] = 0
header[11] = 0
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(header)
}
})
// Test incremental checksum updates
oldIP := []byte{192, 168, 1, 100}
newIP := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
b.Run("optimized_incremental_update", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
}
})
}

View File

@@ -1,145 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
// Add DNAT mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcPort uint16
dstPort uint16
}{
{"TCP", layers.IPProtocolTCP, 12345, 80},
{"UDP", layers.IPProtocolUDP, 12345, 53},
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test outbound DNAT translation
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound packet (should translate destination)
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, translated, "Outbound packet should be translated")
// Verify destination IP was changed
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
// Test inbound reverse DNAT translation
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
// Process inbound packet (should reverse translate source)
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, reversed, "Inbound packet should be reverse translated")
// Verify source IP was changed back to original
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
// Test that checksums are recalculated correctly
if tc.protocol != layers.IPProtocolICMPv4 {
// For TCP/UDP, verify the transport checksum was updated
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
}
})
}
}
// parsePacket helper to create a decoder for testing
func parsePacket(t testing.TB, packetData []byte) *decoder {
t.Helper()
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded)
require.NoError(t, err)
return d
}
// TestDNATMappingManagement tests adding/removing DNAT mappings
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
// Test adding mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
// Verify mapping exists
result, exists := manager.getDNATTranslation(originalIP)
require.True(t, exists)
require.Equal(t, translatedIP, result)
// Test reverse lookup
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
require.True(t, exists)
require.Equal(t, originalIP, reverseResult)
// Test removing mapping
err = manager.RemoveInternalDNATMapping(originalIP)
require.NoError(t, err)
// Verify mapping no longer exists
_, exists = manager.getDNATTranslation(originalIP)
require.False(t, exists)
_, exists = manager.findReverseDNATMapping(translatedIP)
require.False(t, exists)
// Test error cases
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
require.Error(t, err, "Should reject invalid original IP")
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
require.Error(t, err, "Should reject invalid translated IP")
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}

View File

@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.filterOutbound(packetData, 0)
dropped := m.processOutgoingHooks(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -104,12 +104,6 @@ type Manager struct {
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
}
// decoder for packages
@@ -195,7 +189,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
}
m.routingEnabled.Store(false)
@@ -526,6 +519,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -572,14 +581,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// FilterOutBound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size)
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
}
// FilterInbound filters incoming packets
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
return m.filterInbound(packetData, size)
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
}
// UpdateLocalIPs updates the list of local IPs
@@ -587,7 +596,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -609,8 +618,8 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
// for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
return false
}
@@ -714,9 +723,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
return false
}
// filterInbound implements filtering logic for incoming packets.
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -738,15 +747,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false
}

View File

@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
// Test packet
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.filterOutbound(testOut, 0)
manager.processOutgoingHooks(testOut, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(testIn, 0)
manager.dropFilter(testIn, 0)
}
})
}
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
// Prepare test packets simulating bidirectional traffic
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.filterOutbound(outPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring
manager.filterInbound(inPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
}
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
// Data transfer
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
// Connection teardown
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
}
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
// Pre-generate test packets
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.filterOutbound(outPackets[connIdx], 0)
manager.filterInbound(inPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
})
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
})

View File

@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.FilterInbound(packet, 0)
isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.FilterInbound(packet, 0)
isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)

View File

@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.filterInbound(buf.Bytes(), 0) {
if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted")
return
}
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.filterOutbound(buf.Bytes(), 0)
result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result)
require.True(t, hookCalled)
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.filterOutbound(buf.Bytes(), 0)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result)
}
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.filterInbound(testBuf.Bytes(), 0)
drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description)
})
}

View File

@@ -1,96 +0,0 @@
package bind
import (
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/monotime"
)
const (
saveFrequency = int64(5 * time.Second)
)
type PeerRecord struct {
Address netip.AddrPort
LastActivity atomic.Int64 // UnixNano timestamp
}
type ActivityRecorder struct {
mu sync.RWMutex
peers map[string]*PeerRecord // publicKey to PeerRecord map
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
}
func NewActivityRecorder() *ActivityRecorder {
return &ActivityRecorder{
peers: make(map[string]*PeerRecord),
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
}
}
// GetLastActivities returns a snapshot of peer last activity
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
r.mu.RLock()
defer r.mu.RUnlock()
activities := make(map[string]monotime.Time, len(r.peers))
for key, record := range r.peers {
monoTime := record.LastActivity.Load()
activities[key] = monotime.Time(monoTime)
}
return activities
}
// UpsertAddress adds or updates the address for a publicKey
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
r.mu.Lock()
defer r.mu.Unlock()
var record *PeerRecord
record, exists := r.peers[publicKey]
if exists {
delete(r.addrToPeer, record.Address)
record.Address = address
} else {
record = &PeerRecord{
Address: address,
}
record.LastActivity.Store(int64(monotime.Now()))
r.peers[publicKey] = record
}
r.addrToPeer[address] = record
}
func (r *ActivityRecorder) Remove(publicKey string) {
r.mu.Lock()
defer r.mu.Unlock()
if record, exists := r.peers[publicKey]; exists {
delete(r.addrToPeer, record.Address)
delete(r.peers, publicKey)
}
}
// record updates LastActivity for the given address using atomic store
func (r *ActivityRecorder) record(address netip.AddrPort) {
r.mu.RLock()
record, ok := r.addrToPeer[address]
r.mu.RUnlock()
if !ok {
log.Warnf("could not find record for address %s", address)
return
}
now := int64(monotime.Now())
last := record.LastActivity.Load()
if now-last < saveFrequency {
return
}
_ = record.LastActivity.CompareAndSwap(last, now)
}

View File

@@ -1,25 +0,0 @@
package bind
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/monotime"
)
func TestActivityRecorder_GetLastActivities(t *testing.T) {
peer := "peer1"
ar := NewActivityRecorder()
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
activities := ar.GetLastActivities()
p, ok := activities[peer]
if !ok {
t.Fatalf("Expected activity for peer %s, but got none", peer)
}
if monotime.Since(p) > 5*time.Second {
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
}
}

View File

@@ -1,7 +1,6 @@
package bind
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
@@ -52,24 +51,22 @@ type ICEBind struct {
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
activityRecorder *ActivityRecorder
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
activityRecorder: NewActivityRecorder(),
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -103,10 +100,6 @@ func (s *ICEBind) Close() error {
return s.StdNetBind.Close()
}
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
@@ -206,11 +199,6 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
if isTransportPkg(msg.Buffers, msg.N) {
s.activityRecorder.record(addrPort)
}
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
@@ -269,13 +257,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
if isTransportPkg(buffs, sizes[0]) {
if ep, ok := eps[0].(*Endpoint); ok {
c.activityRecorder.record(ep.AddrPort)
}
}
return 1, nil
}
}
@@ -291,19 +272,3 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
}
msgsPool.Put(msgs)
}
func isTransportPkg(buffers [][]byte, n int) bool {
// The first buffer should contain at least 4 bytes for type
if len(buffers[0]) < 4 {
return true
}
// WireGuard packet type is a little-endian uint32 at start
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
// Check if packetType matches known WireGuard message types
if packetType == 4 && n > 32 {
return true
}
return false
}

View File

@@ -11,8 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
@@ -278,7 +276,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
}
return stats, nil
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}

View File

@@ -16,8 +16,6 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -38,18 +36,16 @@ const (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
device *device.Device
deviceName string
activityRecorder *bind.ActivityRecorder
device *device.Device
deviceName string
uapiListener net.Listener
}
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
device: device,
deviceName: deviceName,
}
wgCfg.startUAPI()
return wgCfg
@@ -91,19 +87,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return ipcErr
}
if endpoint != nil {
addr, err := netip.ParseAddr(endpoint.IP.String())
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
@@ -120,10 +104,7 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
c.activityRecorder.Remove(peerKey)
return ipcErr
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
@@ -224,10 +205,6 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
return parseStatus(c.deviceName, ipcStr)
}
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
return c.activityRecorder.GetLastActivities()
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error

View File

@@ -79,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -9,11 +9,11 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// FilterOutbound filter outgoing packets from host to external destinations
FilterOutbound(packetData []byte, size int) bool
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.FilterInbound(buf[offset:], len(buf)) {
if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()

View File

@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -8,7 +8,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/monotime"
)
type WGConfigurer interface {
@@ -20,5 +19,4 @@ type WGConfigurer interface {
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
)
const (
@@ -30,11 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
var (
// ErrIfaceNotFound is returned when the WireGuard interface is not found
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
)
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -123,9 +117,6 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
@@ -135,9 +126,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.RemovePeer(peerKey)
@@ -147,9 +135,6 @@ func (w *WGIface) RemovePeer(peerKey string) error {
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.AddAllowedIP(peerKey, allowedIP)
@@ -159,9 +144,6 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
@@ -232,29 +214,10 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
// GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.GetStats()
}
func (w *WGIface) LastActivities() map[string]monotime.Time {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return nil
}
return w.configurer.LastActivities()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.FullStats()
}

View File

@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
}
// RemovePacketHook mocks base method.

View File

@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// SetNetwork mocks base method.

View File

@@ -41,12 +41,9 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
t.tundev = nsTunDev
var skipProxy bool
if val := os.Getenv(EnvSkipProxy); val != "" {
skipProxy, err = strconv.ParseBool(val)
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
if skipProxy {
return nsTunDev, tunNet, nil

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
type ProxyBind struct {
@@ -29,17 +28,6 @@ type ProxyBind struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
}
// AddTurnConn adds a new connection to the bind.
@@ -66,10 +54,6 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
}
}
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() {
if p.remoteConn == nil {
return
@@ -112,9 +96,6 @@ func (p *ProxyBind) close() error {
if p.closed {
return nil
}
p.closeListener.SetCloseListener(nil)
p.closed = true
p.cancel()
@@ -141,7 +122,6 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}

View File

@@ -11,8 +11,6 @@ import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -28,15 +26,6 @@ type ProxyWrapper struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@@ -54,10 +43,6 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
@@ -92,8 +77,6 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
@@ -134,7 +117,6 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil {
return 0, ctx.Err()
}
p.closeListener.Notify()
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}

View File

@@ -36,8 +36,9 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return ebpf.NewProxyWrapper(w.ebpfProxy)
return &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
}
func (w *KernelFactory) Free() error {

View File

@@ -20,7 +20,9 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
}
func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind)
return &proxyBind.ProxyBind{
Bind: w.bind,
}
}
func (w *USPFactory) Free() error {

View File

@@ -1,19 +0,0 @@
package listener
type CloseListener struct {
listener func()
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.listener = listener
}
func (c *CloseListener) Notify() {
if c.listener != nil {
c.listener()
}
}

View File

@@ -12,5 +12,4 @@ type Proxy interface {
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error
SetDisconnectListener(disconnected func())
}

View File

@@ -98,7 +98,9 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// WGUDPProxy proxies
@@ -29,8 +28,6 @@ type WGUDPProxy struct {
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@@ -38,7 +35,6 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{
localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
}
return p
}
@@ -71,10 +67,6 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr
}
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() {
if p.remoteConn == nil {
@@ -119,8 +111,6 @@ func (p *WGUDPProxy) close() error {
if p.closed {
return nil
}
p.closeListener.SetCloseListener(nil)
p.closed = true
p.cancel()
@@ -151,7 +141,6 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err)
return
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
)
@@ -25,11 +26,11 @@ import (
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
enabledLocally bool
rosenpassEnabled bool
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
@@ -38,12 +39,12 @@ type ConnMgr struct {
lazyCtxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
rosenpassEnabled: engineConfig.RosenpassEnabled,
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
@@ -63,11 +64,6 @@ func (e *ConnMgr) Start(ctx context.Context) {
return
}
if e.rosenpassEnabled {
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
@@ -87,12 +83,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
return nil
}
if e.rosenpassEnabled {
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
return nil
}
log.Warnf("lazy connection manager is enabled by management feature flag")
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager()
@@ -142,7 +133,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
@@ -184,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
@@ -210,7 +201,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
if !ok {
return
}
defer conn.Close(false)
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
@@ -220,27 +211,23 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
}
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
// If locally the lazy connection is disabled, we force the peer connection open.
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return
}
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
return conn, true
}
func (e *ConnMgr) Close() {
@@ -257,7 +244,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
@@ -288,7 +275,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {

View File

@@ -167,7 +167,6 @@ type BundleGenerator struct {
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
archive *zip.Writer
}
@@ -176,7 +175,6 @@ type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
type GeneratorDependencies struct {
@@ -187,12 +185,6 @@ type GeneratorDependencies struct {
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
// Default to 1 log file for backward compatibility when 0 is provided
logFileCount := cfg.LogFileCount
if logFileCount == 0 {
logFileCount = 1
}
return &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
@@ -204,7 +196,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
}
@@ -570,8 +561,32 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err)
}
// add rotated log files based on logFileCount
g.addRotatedLogFiles(logDir)
// add latest rotated log file
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
@@ -655,52 +670,6 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
return nil
}
// addRotatedLogFiles adds rotated log files to the bundle based on logFileCount
func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
if g.logFileCount == 0 {
return
}
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
return
}
if len(files) == 0 {
return
}
// sort files by modification time (newest first)
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().After(fj.ModTime())
})
// include up to logFileCount rotated files
maxFiles := int(g.logFileCount)
if maxFiles > len(files) {
maxFiles = len(files)
}
for i := 0; i < maxFiles; i++ {
name := filepath.Base(files[i])
if err := g.addSingleLogFileGz(files[i], name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,

View File

@@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())

View File

@@ -38,6 +38,7 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
@@ -174,7 +175,8 @@ type Engine struct {
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
statusRecorder *peer.Status
statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
firewall firewallManager.Manager
routeManager routemanager.Manager
@@ -381,13 +383,7 @@ func (e *Engine) Start() error {
}
e.stateManager.Start()
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig)
initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
@@ -404,7 +400,6 @@ func (e *Engine) Start() error {
InitialRoutes: initialRoutes,
StateManager: e.stateManager,
DNSServer: dnsServer,
DNSFeatureFlag: dnsFeatureFlag,
PeerStore: e.peerStore,
DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes,
@@ -456,7 +451,9 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
@@ -491,9 +488,9 @@ func (e *Engine) createFirewall() error {
}
func (e *Engine) initFirewall() error {
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("set firewall: %w", err)
return fmt.Errorf("enable server router: %w", err)
}
if e.config.BlockLANAccess {
@@ -1012,6 +1009,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1022,7 +1021,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
@@ -1257,7 +1255,7 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
}
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close(false)
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
@@ -1304,12 +1302,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1332,16 +1331,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn, ok := e.peerStore.PeerConn(msg.Key)
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
msgType := msg.GetBody().GetType()
if msgType != sProto.Body_GO_IDLE {
e.connMgr.ActivatePeer(e.ctx, conn)
}
switch msg.GetBody().Type {
case sProto.Body_OFFER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1398,8 +1392,6 @@ func (e *Engine) receiveSignalEvents() {
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
case sProto.Body_MODE:
case sProto.Body_GO_IDLE:
e.connMgr.DeactivatePeer(conn)
}
return nil
@@ -1497,12 +1489,7 @@ func (e *Engine) close() {
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
if runtime.GOOS != "android" {
// nolint:nilnil
return nil, nil, false, nil
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx)
info.SetFlags(
e.config.RosenpassEnabled,
@@ -1519,12 +1506,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, false, err
return nil, nil, err
}
routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
dnsFeatureFlag := toDNSFeatureFlag(netMap)
return routes, &dnsCfg, dnsFeatureFlag, nil
return routes, &dnsCfg, nil
}
func (e *Engine) newWgIface() (*iface.WGIface, error) {
@@ -1572,14 +1558,18 @@ func (e *Engine) wgInterfaceCreate() (err error) {
return err
}
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
// due to tests where we are using a mocked version of the DNS server
if e.dnsServer != nil {
return e.dnsServer, nil
return nil, e.dnsServer, nil
}
switch runtime.GOOS {
case "android":
routes, dnsConfig, err := e.readInitialSettings()
if err != nil {
return nil, nil, err
}
dnsServer := dns.NewDefaultServerPermanentUpstream(
e.ctx,
e.wgInterface,
@@ -1590,19 +1580,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
e.config.DisableDNS,
)
go e.mobileDep.DnsReadyListener.OnReady()
return dnsServer, nil
return routes, dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
return nil, dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
if err != nil {
return nil, err
return nil, nil, err
}
return dnsServer, nil
return nil, dnsServer, nil
}
}

View File

@@ -36,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -52,7 +53,6 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
@@ -97,7 +97,6 @@ type MockWGIface struct {
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]monotime.Time
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
@@ -188,13 +187,6 @@ func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc()
}
func (m *MockWGIface) LastActivities() map[string]monotime.Time {
if m.LastActivitiesFunc != nil {
return m.LastActivitiesFunc()
}
return nil
}
func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console")
code := m.Run()
@@ -412,7 +404,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
type testCase struct {
@@ -801,7 +793,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {
@@ -999,7 +991,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {

View File

@@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
)
type wgIfaceBase interface {
@@ -39,5 +38,4 @@ type wgIfaceBase interface {
GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
}

View File

@@ -13,7 +13,7 @@ import (
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
wgIface WgInterface
wgIface lazyconn.WGIface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
@@ -22,7 +22,7 @@ type Listener struct {
isClosed atomic.Bool // use to avoid error log when closing the listener
}
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
wgIface: wgIface,
peerCfg: cfg,
@@ -48,7 +48,7 @@ func (d *Listener) ReadPackets() {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Infof("exit from activity listener")
d.peerCfg.Log.Debugf("exit from activity listener")
} else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
}
@@ -59,11 +59,9 @@ func (d *Listener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
d.peerCfg.Log.Infof("activity detected")
break
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
@@ -73,7 +71,7 @@ func (d *Listener) ReadPackets() {
}
func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
if err := d.conn.Close(); err != nil {
@@ -83,6 +81,7 @@ func (d *Listener) Close() {
}
func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}

View File

@@ -1,27 +1,18 @@
package activity
import (
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}
type Manager struct {
OnActivityChan chan peerid.ConnID
wgIface WgInterface
wgIface lazyconn.WGIface
peers map[peerid.ConnID]*Listener
done chan struct{}
@@ -29,7 +20,7 @@ type Manager struct {
mu sync.Mutex
}
func NewManager(wgIface WgInterface) *Manager {
func NewManager(wgIface lazyconn.WGIface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,

View File

@@ -0,0 +1,75 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
i.Stop()
go i.Start(ctx, timeoutChan)
}

View File

@@ -0,0 +1,156 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
trashHold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), trashHold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check to do not receive timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -1,155 +0,0 @@
package inactivity
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
)
const (
checkInterval = 1 * time.Minute
DefaultInactivityThreshold = 15 * time.Minute
MinimumInactivityThreshold = 1 * time.Minute
)
type WgInterface interface {
LastActivities() map[string]monotime.Time
}
type Manager struct {
inactivePeersChan chan map[string]struct{}
iface WgInterface
interestedPeers map[string]*lazyconn.PeerConfig
inactivityThreshold time.Duration
}
func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager {
inactivityThreshold, err := validateInactivityThreshold(configuredThreshold)
if err != nil {
inactivityThreshold = DefaultInactivityThreshold
log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold)
}
log.Infof("inactivity threshold configured: %v", inactivityThreshold)
return &Manager{
inactivePeersChan: make(chan map[string]struct{}, 1),
iface: iface,
interestedPeers: make(map[string]*lazyconn.PeerConfig),
inactivityThreshold: inactivityThreshold,
}
}
func (m *Manager) InactivePeersChan() chan map[string]struct{} {
if m == nil {
// return a nil channel that blocks forever
return nil
}
return m.inactivePeersChan
}
func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) {
if m == nil {
return
}
if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists {
return
}
peerCfg.Log.Infof("adding peer to inactivity manager")
m.interestedPeers[peerCfg.PublicKey] = peerCfg
}
func (m *Manager) RemovePeer(peer string) {
if m == nil {
return
}
pi, ok := m.interestedPeers[peer]
if !ok {
return
}
pi.Log.Debugf("remove peer from inactivity manager")
delete(m.interestedPeers, peer)
}
func (m *Manager) Start(ctx context.Context) {
if m == nil {
return
}
ticker := newTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C():
idlePeers, err := m.checkStats()
if err != nil {
log.Errorf("error checking stats: %v", err)
return
}
if len(idlePeers) == 0 {
continue
}
m.notifyInactivePeers(ctx, idlePeers)
}
}
}
func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) {
select {
case m.inactivePeersChan <- inactivePeers:
case <-ctx.Done():
return
default:
return
}
}
func (m *Manager) checkStats() (map[string]struct{}, error) {
lastActivities := m.iface.LastActivities()
idlePeers := make(map[string]struct{})
checkTime := time.Now()
for peerID, peerCfg := range m.interestedPeers {
lastActive, ok := lastActivities[peerID]
if !ok {
// when peer is in connecting state
peerCfg.Log.Warnf("peer not found in wg stats")
continue
}
since := monotime.Since(lastActive)
if since > m.inactivityThreshold {
peerCfg.Log.Infof("peer is inactive since time: %s", checkTime.Add(-since).String())
idlePeers[peerID] = struct{}{}
}
}
return idlePeers, nil
}
func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) {
if configuredThreshold == nil {
return DefaultInactivityThreshold, nil
}
if *configuredThreshold < MinimumInactivityThreshold {
return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold)
}
return *configuredThreshold, nil
}

View File

@@ -1,114 +0,0 @@
package inactivity
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
)
type mockWgInterface struct {
lastActivities map[string]monotime.Time
}
func (m *mockWgInterface) LastActivities() map[string]monotime.Time {
return m.lastActivities
}
func TestPeerTriggersInactivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]monotime.Time{
peerID: monotime.Time(int64(monotime.Now()) - int64(20*time.Minute)),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case inactivePeers := <-manager.inactivePeersChan:
assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive")
case <-time.After(1 * time.Second):
t.Fatal("expected inactivity event, but none received")
}
}
func TestPeerTriggersActivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]monotime.Time{
peerID: monotime.Time(int64(monotime.Now()) - int64(5*time.Minute)),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case <-manager.inactivePeersChan:
t.Fatal("expected inactive peer to be marked inactive")
case <-time.After(1 * time.Second):
// No inactivity event should be received
}
}
// fakeTickerMock implements Ticker interface for testing
type fakeTickerMock struct {
CChan chan time.Time
}
func (f *fakeTickerMock) C() <-chan time.Time {
return f.CChan
}
func (f *fakeTickerMock) Stop() {}

View File

@@ -1,24 +0,0 @@
package inactivity
import "time"
var newTicker = func(d time.Duration) Ticker {
return &realTicker{t: time.NewTicker(d)}
}
type Ticker interface {
C() <-chan time.Time
Stop()
}
type realTicker struct {
t *time.Ticker
}
func (r *realTicker) C() <-chan time.Time {
return r.t.C
}
func (r *realTicker) Stop() {
r.t.Stop()
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
@@ -42,46 +43,60 @@ type Config struct {
type Manager struct {
engineCtx context.Context
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex
activityManager *activity.Manager
inactivityManager *inactivity.Manager
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
// Route HA group management
// If any peer in the same HA group is active, all peers in that group should prevent going idle
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex
onInactive chan peerid.ConnID
}
// NewManager creates a new lazy connection manager
// engineCtx is the context for creating peer Connection
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager {
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
engineCtx: engineCtx,
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
peerToHAGroups: make(map[string][]route.HAUniqueID),
haGroupToPeers: make(map[route.HAUniqueID][]string),
onInactive: make(chan peerid.ConnID),
}
if wgIface.IsUserspaceBind() {
m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold)
} else {
log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection")
if config.InactivityThreshold != nil {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
m.inactivityThreshold = *config.InactivityThreshold
} else {
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
}
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m
}
@@ -116,28 +131,24 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
}
}
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups))
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
len(m.haGroupToPeers), len(m.peerToHAGroups))
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
if m.inactivityManager != nil {
go m.inactivityManager.Start(ctx)
}
for {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(peerConnID)
case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs)
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(ctx, peerConnID)
}
}
}
// ExcludePeer marks peers for a permanent connection
@@ -145,7 +156,7 @@ func (m *Manager) Start(ctx context.Context) {
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
// this case, we suppose that the connection status is connected or connecting.
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -176,7 +187,7 @@ func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(&peerCfg); err != nil {
if err := m.addActivePeer(ctx, peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue
}
@@ -186,7 +197,7 @@ func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
return added
}
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -206,6 +217,9 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
return false, err
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
@@ -215,7 +229,7 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(peerCfg)
m.activateNewPeerInActiveGroup(ctx, peerCfg)
}
return false, nil
@@ -223,7 +237,7 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
// AddActivePeers adds a list of peers to the lazy connection manager
// suppose these peers was in connected or in connecting states
func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -233,7 +247,7 @@ func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
continue
}
if err := m.addActivePeer(&cfg); err != nil {
if err := m.addActivePeer(ctx, cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err
}
@@ -250,7 +264,7 @@ func (m *Manager) RemovePeer(peerID string) {
// ActivatePeer activates a peer connection when a signal message is received
// Also activates all peers in the same HA groups as this peer
func (m *Manager) ActivatePeer(peerID string) (found bool) {
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, mp := m.getPeerForActivation(peerID)
@@ -258,43 +272,15 @@ func (m *Manager) ActivatePeer(peerID string) (found bool) {
return false
}
cfg.Log.Infof("activate peer from inactive state by remote signal message")
if !m.activateSinglePeer(cfg, mp) {
if !m.activateSinglePeer(ctx, cfg, mp) {
return false
}
m.activateHAGroupPeers(cfg)
m.activateHAGroupPeers(ctx, peerID)
return true
}
func (m *Manager) DeactivatePeer(peerID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
// Returns nil values if the peer should be skipped
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
@@ -316,36 +302,41 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma
return cfg, mp
}
// activateSinglePeer activates a single peer
// return true if the peer was activated, false if it was already active
func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
if mp.expectedWatcher == watcherInactivity {
// activateSinglePeer activates a single peer (internal method)
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
return false
}
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
m.inactivityManager.AddPeer(cfg)
cfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
return true
}
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
var peersToActivate []string
m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey]
haGroups := m.peerToHAGroups[triggerPeerID]
if len(haGroups) == 0 {
m.routesMu.RUnlock()
triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups")
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, peerID := range peers {
if peerID != triggeredPeerCfg.PublicKey {
if peerID != triggerPeerID {
peersToActivate = append(peersToActivate, peerID)
}
}
@@ -359,16 +350,16 @@ func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
continue
}
if m.activateSinglePeer(cfg, mp) {
if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey)
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
}
if activatedCount > 0 {
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
activatedCount, triggeredPeerCfg.PublicKey, haGroups)
activatedCount, triggerPeerID, haGroups)
}
}
@@ -403,13 +394,13 @@ func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool)
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(&peerCfg, mp) {
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
return
}
@@ -417,19 +408,23 @@ func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error {
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return nil
}
m.managedPeers[peerCfg.PublicKey] = peerCfg
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: peerCfg,
peerCfg: &peerCfg,
expectedWatcher: watcherInactivity,
}
m.inactivityManager.AddPeer(peerCfg)
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
go im.Start(ctx, m.onInactive)
return nil
}
@@ -441,7 +436,12 @@ func (m *Manager) removePeer(peerID string) {
cfg.Log.Infof("removing lazy peer")
m.inactivityManager.RemovePeer(cfg.PublicKey)
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID)
@@ -451,8 +451,12 @@ func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
@@ -466,7 +470,7 @@ func (m *Manager) close() {
}
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool {
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
@@ -476,45 +480,38 @@ func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID
}
for _, haGroup := range haGroups {
if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active {
return true
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// Other member is still connected, defer idle
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
return true
}
}
}
return false
}
func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// If any peer in the group is active, do defer idle
if _, isInactive := inactivePeers[groupPeerID]; !isInactive {
return true
}
}
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -531,56 +528,100 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
mp.peerCfg.Log.Infof("detected peer activity")
if !m.activateSinglePeer(mp.peerCfg, mp) {
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
return
}
m.activateHAGroupPeers(mp.peerCfg)
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
for peerID := range peerIDs {
peerCfg, ok := m.managedPeers[peerID]
if !ok {
log.Errorf("peer not found by peerId: %v", peerID)
continue
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
iw, ok := m.inactivityMonitors[peerConnID]
if ok {
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
iw.ResetMonitor(ctx, m.onInactive)
} else {
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
}
return
}
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID)
continue
}
mp.peerCfg.Log.Infof("connection timed out")
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
continue
}
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) {
mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers")
continue
}
mp.peerCfg.Log.Infof("start activity monitor")
mp.peerCfg.Log.Infof("connection timed out")
mp.expectedWatcher = watcherActivity
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
// just in case free up
m.inactivityMonitors[peerConnID].PauseTimer()
mp.expectedWatcher = watcherActivity
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
continue
}
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

View File

@@ -6,13 +6,9 @@ import (
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
)
type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
LastActivities() map[string]monotime.Time
}

View File

@@ -148,7 +148,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
return nil, err
}

View File

@@ -117,9 +117,10 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
wg sync.WaitGroup
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
peerConnDispatcher *dispatcher.ConnectionDispatcher
wg sync.WaitGroup
// debug purpose
dumpState *stateDump
@@ -135,17 +136,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
peerConnDispatcher: services.PeerConnDispatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
}
return conn, nil
@@ -167,7 +169,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -224,7 +226,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
}
// Close closes this peer Conn issuing a close event to the Conn closeCh
func (conn *Conn) Close(signalToRemote bool) {
func (conn *Conn) Close() {
conn.mu.Lock()
defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock()
@@ -234,12 +236,6 @@ func (conn *Conn) Close(signalToRemote bool) {
return
}
if signalToRemote {
if err := conn.signaler.SignalIdle(conn.config.Key); err != nil {
conn.Log.Errorf("failed to signal idle state to peer: %v", err)
}
}
conn.Log.Infof("close peer connection")
conn.ctxCancel()
@@ -408,10 +404,15 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
wgConfigWorkaround()
oldState := conn.currentConnPriority
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
if oldState == conntype.None {
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
}
func (conn *Conn) onICEStateDisconnected() {
@@ -449,6 +450,7 @@ func (conn *Conn) onICEStateDisconnected() {
} else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
}
changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -489,8 +491,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
conn.dumpState.NewLocalProxy()
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
@@ -530,6 +530,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
func (conn *Conn) onRelayDisconnected() {
@@ -544,7 +545,11 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
}
if conn.wgProxyRelay != nil {

View File

@@ -68,13 +68,3 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
return nil
}
func (s *Signaler) SignalIdle(remoteKey string) error {
return s.signal.Send(&sProto.Message{
Key: s.wgPrivateKey.PublicKey().String(),
RemoteKey: remoteKey,
Body: &sProto.Body{
Type: sProto.Body_GO_IDLE,
},
})
}

View File

@@ -19,7 +19,6 @@ type RelayConnInfo struct {
}
type WorkerRelay struct {
peerCtx context.Context
log *log.Entry
isController bool
config ConnConfig
@@ -34,9 +33,8 @@ type WorkerRelay struct {
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,
isController: ctrl,
config: config,
@@ -64,7 +62,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection")

View File

@@ -95,17 +95,6 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
}
func (s *Store) PeerConnIdle(pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
p.Close(true)
}
func (s *Store) PeerConnClose(pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
@@ -114,7 +103,7 @@ func (s *Store) PeerConnClose(pubKey string) {
if !ok {
return
}
p.Close(false)
p.Close()
}
func (s *Store) PeersPubKey() []string {

View File

@@ -10,10 +10,11 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
@@ -552,16 +553,41 @@ func (w *Watcher) Stop() {
w.currentChosenStatus = nil
}
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
switch handlerType(params.Route, params.UseNewDNSRoute) {
func HandlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler {
switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDnsInterceptor:
return dnsinterceptor.New(params)
return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
wgInterface,
peerStore,
)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(params.WgInterface)
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
return dynamic.NewRoute(params, dnsAddr)
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
default:
return static.NewRoute(params)
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
}

View File

@@ -7,12 +7,12 @@ import (
"time"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct {
name string
statuses map[route.ID]routerPeerStatus
@@ -811,12 +811,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
params := common.HandlerParams{
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
}
// create new clientNetwork
client := &Watcher{
handler: static.NewRoute(params),
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes,
currentChosen: currentRoute,
}

View File

@@ -1,28 +0,0 @@
package common
import (
"time"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type HandlerParams struct {
Route *route.Route
RouteRefCounter *refcounter.RouteRefCounter
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
DnsRouterInterval time.Duration
StatusRecorder *peer.Status
WgInterface iface.WGIface
DnsServer dns.Server
PeerStore *peerstore.Store
UseNewDNSRoute bool
Firewall manager.Manager
FakeIPManager *fakeip.Manager
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
@@ -13,14 +12,11 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
@@ -28,11 +24,6 @@ import (
type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
RemoveInternalDNATMapping(netip.Addr) error
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface {
Name() string
Address() wgaddr.Address
@@ -49,22 +40,26 @@ type DnsInterceptor struct {
interceptedDomains domainMap
wgInterface wgInterface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
}
func New(params common.HandlerParams) *DnsInterceptor {
func New(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
wgInterface wgInterface,
peerStore *peerstore.Store,
) *DnsInterceptor {
return &DnsInterceptor{
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
statusRecorder: params.StatusRecorder,
dnsServer: params.DnsServer,
wgInterface: params.WgInterface,
peerStore: params.PeerStore,
firewall: params.Firewall,
fakeIPManager: params.FakeIPManager,
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder,
dnsServer: dnsServer,
wgInterface: wgInterface,
interceptedDomains: make(domainMap),
peerStore: peerStore,
}
}
@@ -83,13 +78,9 @@ func (d *DnsInterceptor) RemoveRoute() error {
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
// Routes should use fake IPs
routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
}
// AllowedIPs should use real IPs
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
@@ -97,10 +88,8 @@ func (d *DnsInterceptor) RemoveRoute() error {
}
}
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
}
d.cleanupDNATMappings()
for _, domain := range d.route.Domains {
d.statusRecorder.DeleteResolvedDomainsStates(domain)
}
@@ -113,68 +102,6 @@ func (d *DnsInterceptor) RemoveRoute() error {
return nberrors.FormatErrorOrNil(merr)
}
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
return realPrefix
}
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
}
return realPrefix
}
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
// AllowedIPs always use real IPs
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
if err != nil {
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
}
if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
realPrefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
routePrefix := d.transformRealToFakePrefix(realPrefix)
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
}
// Add to AllowedIPs if we have a current peer (uses real IPs)
if d.currentPeerKey == "" {
return nil
}
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
}
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
if d.currentPeerKey == "" {
return nil
}
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
}
return nil
}
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -182,9 +109,14 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
// AllowedIPs use real IPs
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
merr = multierror.Append(merr, err)
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
}
@@ -200,7 +132,6 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
var merr *multierror.Error
for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
@@ -356,8 +287,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes)
}
}
@@ -368,22 +297,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
return nil
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -392,163 +305,70 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
var merr *multierror.Error
var dnatMappings map[netip.Addr]netip.Addr
// Handle DNAT mappings for new prefixes
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
dnatMappings = make(map[netip.Addr]netip.Addr)
for _, prefix := range toAdd {
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
// Add new prefixes
for _, prefix := range toAdd {
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
merr = multierror.Append(merr, err)
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
continue
}
if d.currentPeerKey == "" {
continue
}
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
resolvedDomain.SafeString(),
ref.Out,
)
}
}
d.addDNATMappings(dnatMappings)
if !d.route.KeepRoute {
// Remove old prefixes
for _, prefix := range toRemove {
// Routes use fake IPs
routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
}
// AllowedIPs use real IPs
if err := d.removeAllowedIP(prefix); err != nil {
merr = multierror.Append(merr, err)
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
d.removeDNATMappings(toRemove)
}
// Update domain prefixes using resolved domain as key - store real IPs
// Update domain prefixes using resolved domain as key
if len(toAdd) > 0 || len(toRemove) > 0 {
if d.route.KeepRoute {
// replace stored prefixes with old + added
// nolint:gocritic
newPrefixes = append(oldPrefixes, toAdd...)
}
d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
if len(realPrefixes) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for _, prefix := range realPrefixes {
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
}
}
}
}
// internalDnatFw checks if the firewall supports internal DNAT
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
if d.firewall == nil || runtime.GOOS != "android" {
return nil, false
}
fw, ok := d.firewall.(internalDNATer)
return fw, ok
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
if len(mappings) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
// cleanupDNATMappings removes all DNAT mappings for this interceptor
func (d *DnsInterceptor) cleanupDNATMappings() {
if _, ok := d.internalDnatFw(); !ok {
return
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes)
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
if _, ok := d.internalDnatFw(); !ok {
return
}
// Replace A and AAAA records with fake IPs
for _, answer := range reply.Answer {
switch rr := answer.(type) {
case *dns.A:
realIP, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
realIP, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {

View File

@@ -14,7 +14,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
@@ -53,16 +52,24 @@ type Route struct {
resolverAddr string
}
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
resolverAddr string,
) *Route {
return &Route{
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
interval: params.DnsRouterInterval,
statusRecorder: params.StatusRecorder,
wgInterface: params.WgInterface,
resolverAddr: resolverAddr,
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
resolverAddr: resolverAddr,
}
}

View File

@@ -1,93 +0,0 @@
package fakeip
import (
"fmt"
"net/netip"
"sync"
)
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
type Manager struct {
mu sync.Mutex
nextIP netip.Addr // Next IP to allocate
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
baseIP netip.Addr // First usable IP: 240.0.0.1
maxIP netip.Addr // Last usable IP: 240.255.255.254
}
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
func NewManager() *Manager {
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
return &Manager{
nextIP: baseIP,
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: baseIP,
maxIP: maxIP,
}
}
// AllocateFakeIP allocates a fake IP for the given real IP
// Returns the fake IP, or existing fake IP if already allocated
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
if !realIP.Is4() {
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
}
m.mu.Lock()
defer m.mu.Unlock()
if fakeIP, exists := m.allocated[realIP]; exists {
return fakeIP, nil
}
startIP := m.nextIP
for {
currentIP := m.nextIP
// Advance to next IP, wrapping at boundary
if m.nextIP.Compare(m.maxIP) >= 0 {
m.nextIP = m.baseIP
} else {
m.nextIP = m.nextIP.Next()
}
// Check if current IP is available
if _, inUse := m.fakeToReal[currentIP]; !inUse {
m.allocated[realIP] = currentIP
m.fakeToReal[currentIP] = realIP
return currentIP, nil
}
// Prevent infinite loop if all IPs exhausted
if m.nextIP.Compare(startIP) == 0 {
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
}
}
}
// GetFakeIP returns the fake IP for a real IP if it exists
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
m.mu.Lock()
defer m.mu.Unlock()
fakeIP, exists := m.allocated[realIP]
return fakeIP, exists
}
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
m.mu.Lock()
defer m.mu.Unlock()
realIP, exists := m.fakeToReal[fakeIP]
return realIP, exists
}
// GetFakeIPBlock returns the fake IP block used by this manager
func (m *Manager) GetFakeIPBlock() netip.Prefix {
return netip.MustParsePrefix("240.0.0.0/8")
}

View File

@@ -1,240 +0,0 @@
package fakeip
import (
"net/netip"
"sync"
"testing"
)
func TestNewManager(t *testing.T) {
manager := NewManager()
if manager.baseIP.String() != "240.0.0.1" {
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
}
if manager.maxIP.String() != "240.255.255.254" {
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
}
if manager.nextIP.Compare(manager.baseIP) != 0 {
t.Errorf("Expected nextIP to start at baseIP")
}
}
func TestAllocateFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("8.8.8.8")
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
if !fakeIP.Is4() {
t.Error("Fake IP should be IPv4")
}
// Check it's in the correct range
if fakeIP.As4()[0] != 240 {
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
}
// Should return same fake IP for same real IP
fakeIP2, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to get existing fake IP: %v", err)
}
if fakeIP.Compare(fakeIP2) != 0 {
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
}
}
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
manager := NewManager()
realIPv6 := netip.MustParseAddr("2001:db8::1")
_, err := manager.AllocateFakeIP(realIPv6)
if err == nil {
t.Error("Expected error for IPv6 address")
}
}
func TestGetFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("1.1.1.1")
// Should not exist initially
_, exists := manager.GetFakeIP(realIP)
if exists {
t.Error("Fake IP should not exist before allocation")
}
// Allocate and check
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate: %v", err)
}
fakeIP, exists := manager.GetFakeIP(realIP)
if !exists {
t.Error("Fake IP should exist after allocation")
}
if fakeIP.Compare(expectedFakeIP) != 0 {
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
}
}
func TestMultipleAllocations(t *testing.T) {
manager := NewManager()
allocations := make(map[netip.Addr]netip.Addr)
// Allocate multiple IPs
for i := 1; i <= 100; i++ {
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
}
// Check for duplicates
for _, existingFake := range allocations {
if fakeIP.Compare(existingFake) == 0 {
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
}
}
allocations[realIP] = fakeIP
}
// Verify all allocations can be retrieved
for realIP, expectedFake := range allocations {
actualFake, exists := manager.GetFakeIP(realIP)
if !exists {
t.Errorf("Missing allocation for %s", realIP.String())
}
if actualFake.Compare(expectedFake) != 0 {
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
}
}
}
func TestGetFakeIPBlock(t *testing.T) {
manager := NewManager()
block := manager.GetFakeIPBlock()
expected := "240.0.0.0/8"
if block.String() != expected {
t.Errorf("Expected %s, got %s", expected, block.String())
}
}
func TestConcurrentAccess(t *testing.T) {
manager := NewManager()
const numGoroutines = 50
const allocationsPerGoroutine = 10
var wg sync.WaitGroup
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
// Concurrent allocations
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for j := 0; j < allocationsPerGoroutine; j++ {
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
return
}
results <- fakeIP
}
}(i)
}
wg.Wait()
close(results)
// Check for duplicates
seen := make(map[netip.Addr]bool)
count := 0
for fakeIP := range results {
if seen[fakeIP] {
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
}
seen[fakeIP] = true
count++
}
if count != numGoroutines*allocationsPerGoroutine {
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
}
}
func TestIPExhaustion(t *testing.T) {
// Create a manager with limited range for testing
manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
}
// Allocate all available IPs
realIPs := []netip.Addr{
netip.MustParseAddr("1.0.0.1"),
netip.MustParseAddr("1.0.0.2"),
netip.MustParseAddr("1.0.0.3"),
}
for _, realIP := range realIPs {
_, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
}
// Try to allocate one more - should fail
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
if err == nil {
t.Error("Expected exhaustion error")
}
}
func TestWrapAround(t *testing.T) {
// Create manager starting near the end of range
manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
}
// Allocate the last IP
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
if err != nil {
t.Fatalf("Failed to allocate first IP: %v", err)
}
if fakeIP1.String() != "240.0.0.254" {
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
}
// Next allocation should wrap around to the beginning
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
if err != nil {
t.Fatalf("Failed to allocate second IP: %v", err)
}
if fakeIP2.String() != "240.0.0.1" {
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
}
}

View File

@@ -8,11 +8,9 @@ import (
"net/netip"
"net/url"
"runtime"
"slices"
"sync"
"time"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -26,8 +24,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@@ -53,7 +49,7 @@ type Manager interface {
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
SetFirewall(firewall.Manager) error
EnableServerRouter(firewall firewall.Manager) error
Stop(stateManager *statemanager.Manager)
}
@@ -67,7 +63,6 @@ type ManagerConfig struct {
InitialRoutes []*route.Route
StateManager *statemanager.Manager
DNSServer dns.Server
DNSFeatureFlag bool
PeerStore *peerstore.Store
DisableClientRoutes bool
DisableServerRoutes bool
@@ -94,13 +89,11 @@ type DefaultManager struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
dnsServer dns.Server
firewall firewall.Manager
peerStore *peerstore.Store
useNewDNSRoute bool
disableClientRoutes bool
disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.Manager
}
func NewManager(config ManagerConfig) *DefaultManager {
@@ -136,31 +129,11 @@ func NewManager(config ManagerConfig) *DefaultManager {
}
if runtime.GOOS == "android" {
dm.setupAndroidRoutes(config)
cr := dm.initialClientRoutes(config.InitialRoutes)
dm.notifier.SetInitialClientRoutes(cr)
}
return dm
}
func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
cr := m.initialClientRoutes(config.InitialRoutes)
routesForComparison := slices.Clone(cr)
if config.DNSFeatureFlag {
m.fakeIPManager = fakeip.NewManager()
id := uuid.NewString()
fakeIPRoute := &route.Route{
ID: route.ID(id),
Network: m.fakeIPManager.GetFakeIPBlock(),
NetID: route.NetID(id),
Peer: m.pubKey,
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
}
func (m *DefaultManager) setupRefCounters(useNoop bool) {
m.routeRefCounter = refcounter.New(
@@ -249,16 +222,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
return routeselector.NewRouteSelector()
}
// SetFirewall sets the firewall manager for the DefaultManager
// Not thread-safe, should be called before starting the manager
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
m.firewall = firewall
if m.disableServerRoutes || firewall == nil {
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
if m.disableServerRoutes {
log.Info("server routes are disabled")
return nil
}
if firewall == nil {
return errors.New("firewall manager is not set")
}
var err error
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil {
@@ -326,20 +299,17 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
}
for id, route := range toAdd {
params := common.HandlerParams{
Route: route,
RouteRefCounter: m.routeRefCounter,
AllowedIPsRefCounter: m.allowedIPsRefCounter,
DnsRouterInterval: m.dnsRouteInterval,
StatusRecorder: m.statusRecorder,
WgInterface: m.wgInterface,
DnsServer: m.dnsServer,
PeerStore: m.peerStore,
UseNewDNSRoute: m.useNewDNSRoute,
Firewall: m.firewall,
FakeIPManager: m.fakeIPManager,
}
handler := client.HandlerFromRoute(params)
handler := client.HandlerFromRoute(
route,
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsRouteInterval,
m.statusRecorder,
m.wgInterface,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue
@@ -547,7 +517,6 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
for _, routes := range crMap {
rs = append(rs, routes...)
}
return rs
}

View File

@@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
}
func (m *MockManager) SetFirewall(firewall.Manager) error {
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
panic("implement me")
}

View File

@@ -0,0 +1,124 @@
package notifier
import (
"net/netip"
"runtime"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
initialRouteRanges []string
routeRanges []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
if r.IsDynamic() {
continue
}
nets = append(nets, r.Network.String())
}
sort.Strings(nets)
n.initialRouteRanges = nets
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
var newNets []string
for _, routes := range idMap {
for _, r := range routes {
if r.IsDynamic() {
continue
}
newNets = append(newNets, r.Network.String())
}
}
sort.Strings(newNets)
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
n.routeRanges = newNets
n.notify()
}
// OnNewPrefixes is called from iOS only
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
if !n.hasDiff(n.routeRanges, newNets) {
return
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
}(n.listener)
}
func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
for i, v := range a {
if v != b[i] {
return true
}
}
return false
}
func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
func addIPv6RangeIfNeeded(inputRanges []string) []string {
ranges := inputRanges
for _, r := range inputRanges {
// we are intentionally adding the ipv6 default range in case of ipv4 default range
// to ensure that all traffic is managed by the tunnel interface on android
if r == "0.0.0.0/0" {
ranges = append(ranges, "::/0")
break
}
}
return ranges
}

View File

@@ -1,127 +0,0 @@
//go:build android
package notifier
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
// initialRoutes contains fake IP block for interface configuration
filteredInitial := make([]*route.Route, 0)
for _, r := range initialRoutes {
if r.IsDynamic() {
continue
}
filteredInitial = append(filteredInitial, r)
}
n.initialRoutes = filteredInitial
// routesForComparison excludes fake IP block for comparison with new routes
filteredComparison := make([]*route.Route, 0)
for _, r := range routesForComparison {
if r.IsDynamic() {
continue
}
filteredComparison = append(filteredComparison, r)
}
n.currentRoutes = filteredComparison
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
for _, r := range routes {
if r.IsDynamic() {
continue
}
newRoutes = append(newRoutes, r)
}
}
if !n.hasRouteDiff(n.currentRoutes, newRoutes) {
return
}
n.currentRoutes = newRoutes
n.notify()
}
func (n *Notifier) OnNewPrefixes([]netip.Prefix) {
// Not used on Android
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
routeStrings := n.routesToStrings(n.currentRoutes)
sort.Strings(routeStrings)
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
}(n.listener)
}
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
nets := make([]string, 0, len(routes))
for _, r := range routes {
nets = append(nets, r.NetString())
}
return nets
}
func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool {
slices.SortFunc(a, func(x, y *route.Route) int {
return strings.Compare(x.NetString(), y.NetString())
})
slices.SortFunc(b, func(x, y *route.Route) int {
return strings.Compare(x.NetString(), y.NetString())
})
return !slices.EqualFunc(a, b, func(x, y *route.Route) bool {
return x.NetString() == y.NetString()
})
}
func (n *Notifier) GetInitialRouteRanges() []string {
initialStrings := n.routesToStrings(n.initialRoutes)
sort.Strings(initialStrings)
return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes)
}
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string {
for _, r := range routes {
if r.Network.Addr().Is4() && r.Network.Bits() == 0 {
return append(slices.Clone(inputRanges), "::/0")
}
}
return inputRanges
}

View File

@@ -1,80 +0,0 @@
//go:build ios
package notifier
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
currentPrefixes []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
if slices.Equal(n.currentPrefixes, newNets) {
return
}
n.currentPrefixes = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ","))
}(n.listener)
}
func (n *Notifier) GetInitialRouteRanges() []string {
return nil
}
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string {
for _, r := range inputRanges {
if r == "0.0.0.0/0" {
return append(slices.Clone(inputRanges), "::/0")
}
}
return inputRanges
}

View File

@@ -1,36 +0,0 @@
//go:build !android && !ios
package notifier
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct{}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
// Not used on non-mobile platforms
}
func (n *Notifier) GetInitialRouteRanges() []string {
return []string{}
}

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
@@ -17,11 +16,11 @@ type Route struct {
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(params common.HandlerParams) *Route {
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
return &Route{
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
}
}

View File

@@ -5,7 +5,6 @@ import (
"net"
"net/netip"
"sync"
"sync/atomic"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -53,9 +52,6 @@ type SysOps struct {
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
// seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems
seq atomic.Uint32
}
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
@@ -65,11 +61,6 @@ func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
}
}
//nolint:unused // only used on BSD systems
func (r *SysOps) getSeq() int {
return int(r.seq.Add(1))
}
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr()

View File

@@ -108,7 +108,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action,
Flags: unix.RTF_UP,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Seq: 1,
}
const numAddrs = unix.RTAX_NETMASK + 1

View File

@@ -2290,7 +2290,6 @@ type DebugBundleRequest struct {
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -2353,13 +2352,6 @@ func (x *DebugBundleRequest) GetUploadURL() string {
return ""
}
func (x *DebugBundleRequest) GetLogFileCount() uint32 {
if x != nil {
return x.LogFileCount
}
return 0
}
type DebugBundleResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
@@ -3503,6 +3495,246 @@ func (x *GetEventsResponse) GetEvents() []*SystemEvent {
return nil
}
type SetConfigRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
RosenpassEnabled *bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"`
RosenpassPermissive *bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"`
ServerSSHAllowed *bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"`
DisableAutoConnect *bool `protobuf:"varint,4,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"`
NetworkMonitor *bool `protobuf:"varint,5,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"`
DnsRouteInterval *durationpb.Duration `protobuf:"bytes,6,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"`
DisableClientRoutes *bool `protobuf:"varint,7,opt,name=disable_client_routes,json=disableClientRoutes,proto3,oneof" json:"disable_client_routes,omitempty"`
DisableServerRoutes *bool `protobuf:"varint,8,opt,name=disable_server_routes,json=disableServerRoutes,proto3,oneof" json:"disable_server_routes,omitempty"`
DisableDns *bool `protobuf:"varint,9,opt,name=disable_dns,json=disableDns,proto3,oneof" json:"disable_dns,omitempty"`
DisableFirewall *bool `protobuf:"varint,10,opt,name=disable_firewall,json=disableFirewall,proto3,oneof" json:"disable_firewall,omitempty"`
BlockLanAccess *bool `protobuf:"varint,11,opt,name=block_lan_access,json=blockLanAccess,proto3,oneof" json:"block_lan_access,omitempty"`
LazyConnectionEnabled *bool `protobuf:"varint,12,opt,name=lazyConnectionEnabled,proto3,oneof" json:"lazyConnectionEnabled,omitempty"`
BlockInbound *bool `protobuf:"varint,13,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"`
InterfaceName *string `protobuf:"bytes,14,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"`
WireguardPort *int64 `protobuf:"varint,15,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"`
CustomDNSAddress *string `protobuf:"bytes,16,opt,name=customDNSAddress,proto3,oneof" json:"customDNSAddress,omitempty"`
ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"`
DnsLabels []string `protobuf:"bytes,18,rep,name=dns_labels,json=dnsLabels,proto3" json:"dns_labels,omitempty"`
CleanDNSLabels *bool `protobuf:"varint,19,opt,name=cleanDNSLabels,proto3,oneof" json:"cleanDNSLabels,omitempty"`
NatExternalIPs []string `protobuf:"bytes,20,rep,name=natExternalIPs,proto3" json:"natExternalIPs,omitempty"`
CleanNATExternalIPs *bool `protobuf:"varint,21,opt,name=cleanNATExternalIPs,proto3,oneof" json:"cleanNATExternalIPs,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *SetConfigRequest) Reset() {
*x = SetConfigRequest{}
mi := &file_daemon_proto_msgTypes[52]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *SetConfigRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetConfigRequest) ProtoMessage() {}
func (x *SetConfigRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[52]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead.
func (*SetConfigRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{52}
}
func (x *SetConfigRequest) GetRosenpassEnabled() bool {
if x != nil && x.RosenpassEnabled != nil {
return *x.RosenpassEnabled
}
return false
}
func (x *SetConfigRequest) GetRosenpassPermissive() bool {
if x != nil && x.RosenpassPermissive != nil {
return *x.RosenpassPermissive
}
return false
}
func (x *SetConfigRequest) GetServerSSHAllowed() bool {
if x != nil && x.ServerSSHAllowed != nil {
return *x.ServerSSHAllowed
}
return false
}
func (x *SetConfigRequest) GetDisableAutoConnect() bool {
if x != nil && x.DisableAutoConnect != nil {
return *x.DisableAutoConnect
}
return false
}
func (x *SetConfigRequest) GetNetworkMonitor() bool {
if x != nil && x.NetworkMonitor != nil {
return *x.NetworkMonitor
}
return false
}
func (x *SetConfigRequest) GetDnsRouteInterval() *durationpb.Duration {
if x != nil {
return x.DnsRouteInterval
}
return nil
}
func (x *SetConfigRequest) GetDisableClientRoutes() bool {
if x != nil && x.DisableClientRoutes != nil {
return *x.DisableClientRoutes
}
return false
}
func (x *SetConfigRequest) GetDisableServerRoutes() bool {
if x != nil && x.DisableServerRoutes != nil {
return *x.DisableServerRoutes
}
return false
}
func (x *SetConfigRequest) GetDisableDns() bool {
if x != nil && x.DisableDns != nil {
return *x.DisableDns
}
return false
}
func (x *SetConfigRequest) GetDisableFirewall() bool {
if x != nil && x.DisableFirewall != nil {
return *x.DisableFirewall
}
return false
}
func (x *SetConfigRequest) GetBlockLanAccess() bool {
if x != nil && x.BlockLanAccess != nil {
return *x.BlockLanAccess
}
return false
}
func (x *SetConfigRequest) GetLazyConnectionEnabled() bool {
if x != nil && x.LazyConnectionEnabled != nil {
return *x.LazyConnectionEnabled
}
return false
}
func (x *SetConfigRequest) GetBlockInbound() bool {
if x != nil && x.BlockInbound != nil {
return *x.BlockInbound
}
return false
}
func (x *SetConfigRequest) GetInterfaceName() string {
if x != nil && x.InterfaceName != nil {
return *x.InterfaceName
}
return ""
}
func (x *SetConfigRequest) GetWireguardPort() int64 {
if x != nil && x.WireguardPort != nil {
return *x.WireguardPort
}
return 0
}
func (x *SetConfigRequest) GetCustomDNSAddress() string {
if x != nil && x.CustomDNSAddress != nil {
return *x.CustomDNSAddress
}
return ""
}
func (x *SetConfigRequest) GetExtraIFaceBlacklist() []string {
if x != nil {
return x.ExtraIFaceBlacklist
}
return nil
}
func (x *SetConfigRequest) GetDnsLabels() []string {
if x != nil {
return x.DnsLabels
}
return nil
}
func (x *SetConfigRequest) GetCleanDNSLabels() bool {
if x != nil && x.CleanDNSLabels != nil {
return *x.CleanDNSLabels
}
return false
}
func (x *SetConfigRequest) GetNatExternalIPs() []string {
if x != nil {
return x.NatExternalIPs
}
return nil
}
func (x *SetConfigRequest) GetCleanNATExternalIPs() bool {
if x != nil && x.CleanNATExternalIPs != nil {
return *x.CleanNATExternalIPs
}
return false
}
type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *SetConfigResponse) Reset() {
*x = SetConfigResponse{}
mi := &file_daemon_proto_msgTypes[53]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *SetConfigResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetConfigResponse) ProtoMessage() {}
func (x *SetConfigResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[53]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead.
func (*SetConfigResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{53}
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -3513,7 +3745,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[53]
mi := &file_daemon_proto_msgTypes[55]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -3525,7 +3757,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[53]
mi := &file_daemon_proto_msgTypes[55]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -3754,15 +3986,14 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x88\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" +
"\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" +
"\n" +
"systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" +
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" +
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\"}\n" +
"\x13DebugBundleResponse\x12\x12\n" +
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +
@@ -3853,7 +4084,52 @@ const file_daemon_proto_rawDesc = "" +
"\x06SYSTEM\x10\x04\"\x12\n" +
"\x10GetEventsRequest\"@\n" +
"\x11GetEventsResponse\x12+\n" +
"\x06events\x18\x01 \x03(\v2\x13.daemon.SystemEventR\x06events*b\n" +
"\x06events\x18\x01 \x03(\v2\x13.daemon.SystemEventR\x06events\"\x98\v\n" +
"\x10SetConfigRequest\x12/\n" +
"\x10rosenpassEnabled\x18\x01 \x01(\bH\x00R\x10rosenpassEnabled\x88\x01\x01\x125\n" +
"\x13rosenpassPermissive\x18\x02 \x01(\bH\x01R\x13rosenpassPermissive\x88\x01\x01\x12/\n" +
"\x10serverSSHAllowed\x18\x03 \x01(\bH\x02R\x10serverSSHAllowed\x88\x01\x01\x123\n" +
"\x12disableAutoConnect\x18\x04 \x01(\bH\x03R\x12disableAutoConnect\x88\x01\x01\x12+\n" +
"\x0enetworkMonitor\x18\x05 \x01(\bH\x04R\x0enetworkMonitor\x88\x01\x01\x12J\n" +
"\x10dnsRouteInterval\x18\x06 \x01(\v2\x19.google.protobuf.DurationH\x05R\x10dnsRouteInterval\x88\x01\x01\x127\n" +
"\x15disable_client_routes\x18\a \x01(\bH\x06R\x13disableClientRoutes\x88\x01\x01\x127\n" +
"\x15disable_server_routes\x18\b \x01(\bH\aR\x13disableServerRoutes\x88\x01\x01\x12$\n" +
"\vdisable_dns\x18\t \x01(\bH\bR\n" +
"disableDns\x88\x01\x01\x12.\n" +
"\x10disable_firewall\x18\n" +
" \x01(\bH\tR\x0fdisableFirewall\x88\x01\x01\x12-\n" +
"\x10block_lan_access\x18\v \x01(\bH\n" +
"R\x0eblockLanAccess\x88\x01\x01\x129\n" +
"\x15lazyConnectionEnabled\x18\f \x01(\bH\vR\x15lazyConnectionEnabled\x88\x01\x01\x12(\n" +
"\rblock_inbound\x18\r \x01(\bH\fR\fblockInbound\x88\x01\x01\x12)\n" +
"\rinterfaceName\x18\x0e \x01(\tH\rR\rinterfaceName\x88\x01\x01\x12)\n" +
"\rwireguardPort\x18\x0f \x01(\x03H\x0eR\rwireguardPort\x88\x01\x01\x12/\n" +
"\x10customDNSAddress\x18\x10 \x01(\tH\x0fR\x10customDNSAddress\x88\x01\x01\x120\n" +
"\x13extraIFaceBlacklist\x18\x11 \x03(\tR\x13extraIFaceBlacklist\x12\x1d\n" +
"\n" +
"dns_labels\x18\x12 \x03(\tR\tdnsLabels\x12+\n" +
"\x0ecleanDNSLabels\x18\x13 \x01(\bH\x10R\x0ecleanDNSLabels\x88\x01\x01\x12&\n" +
"\x0enatExternalIPs\x18\x14 \x03(\tR\x0enatExternalIPs\x125\n" +
"\x13cleanNATExternalIPs\x18\x15 \x01(\bH\x11R\x13cleanNATExternalIPs\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x16\n" +
"\x14_rosenpassPermissiveB\x13\n" +
"\x11_serverSSHAllowedB\x15\n" +
"\x13_disableAutoConnectB\x11\n" +
"\x0f_networkMonitorB\x13\n" +
"\x11_dnsRouteIntervalB\x18\n" +
"\x16_disable_client_routesB\x18\n" +
"\x16_disable_server_routesB\x0e\n" +
"\f_disable_dnsB\x13\n" +
"\x11_disable_firewallB\x13\n" +
"\x11_block_lan_accessB\x18\n" +
"\x16_lazyConnectionEnabledB\x10\n" +
"\x0e_block_inboundB\x10\n" +
"\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x13\n" +
"\x11_customDNSAddressB\x11\n" +
"\x0f_cleanDNSLabelsB\x16\n" +
"\x14_cleanNATExternalIPs\"\x13\n" +
"\x11SetConfigResponse*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -3862,7 +4138,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xb3\v\n" +
"\x05TRACE\x10\a2\xf7\v\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -3885,7 +4161,8 @@ const file_daemon_proto_rawDesc = "" +
"\x18SetNetworkMapPersistence\x12'.daemon.SetNetworkMapPersistenceRequest\x1a(.daemon.SetNetworkMapPersistenceResponse\"\x00\x12H\n" +
"\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" +
"\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" +
"\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00B\bZ\x06/protob\x06proto3"
"\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12B\n" +
"\tSetConfig\x12\x18.daemon.SetConfigRequest\x1a\x19.daemon.SetConfigResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -3900,7 +4177,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 55)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 57)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
@@ -3957,18 +4234,20 @@ var file_daemon_proto_goTypes = []any{
(*SystemEvent)(nil), // 52: daemon.SystemEvent
(*GetEventsRequest)(nil), // 53: daemon.GetEventsRequest
(*GetEventsResponse)(nil), // 54: daemon.GetEventsResponse
nil, // 55: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 56: daemon.PortInfo.Range
nil, // 57: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 58: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 59: google.protobuf.Timestamp
(*SetConfigRequest)(nil), // 55: daemon.SetConfigRequest
(*SetConfigResponse)(nil), // 56: daemon.SetConfigResponse
nil, // 57: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 58: daemon.PortInfo.Range
nil, // 59: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 60: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 61: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
58, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
60, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
59, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
59, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
58, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
61, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
61, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
60, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -3977,8 +4256,8 @@ var file_daemon_proto_depIdxs = []int32{
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
55, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
56, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
57, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
58, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -3989,55 +4268,58 @@ var file_daemon_proto_depIdxs = []int32{
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
59, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
57, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
61, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
59, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
27, // 28: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
4, // 29: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
6, // 30: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
8, // 31: daemon.DaemonService.Up:input_type -> daemon.UpRequest
10, // 32: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
12, // 33: daemon.DaemonService.Down:input_type -> daemon.DownRequest
14, // 34: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
23, // 35: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
25, // 36: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
25, // 37: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
3, // 38: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
32, // 39: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
34, // 40: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
36, // 41: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
39, // 42: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
41, // 43: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
43, // 44: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
45, // 45: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
48, // 46: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
51, // 47: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
53, // 48: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
5, // 49: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 50: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 51: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 52: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 53: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 54: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 55: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 56: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 57: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 58: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 59: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 60: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 61: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 62: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 63: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 64: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 65: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
50, // 66: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 67: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 68: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
49, // [49:69] is the sub-list for method output_type
29, // [29:49] is the sub-list for method input_type
29, // [29:29] is the sub-list for extension type_name
29, // [29:29] is the sub-list for extension extendee
0, // [0:29] is the sub-list for field type_name
60, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 29: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
4, // 30: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
6, // 31: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
8, // 32: daemon.DaemonService.Up:input_type -> daemon.UpRequest
10, // 33: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
12, // 34: daemon.DaemonService.Down:input_type -> daemon.DownRequest
14, // 35: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
23, // 36: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
25, // 37: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
25, // 38: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
3, // 39: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
32, // 40: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
34, // 41: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
36, // 42: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
39, // 43: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
41, // 44: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
43, // 45: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
45, // 46: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
48, // 47: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
51, // 48: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
53, // 49: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
55, // 50: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
5, // 51: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 52: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 53: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 54: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 55: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 56: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 57: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 58: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 59: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 60: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 61: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 62: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 63: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 64: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 65: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 66: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 67: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
50, // 68: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 69: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 70: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
56, // 71: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
51, // [51:72] is the sub-list for method output_type
30, // [30:51] is the sub-list for method input_type
30, // [30:30] is the sub-list for extension type_name
30, // [30:30] is the sub-list for extension extendee
0, // [0:30] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -4052,13 +4334,14 @@ func file_daemon_proto_init() {
}
file_daemon_proto_msgTypes[45].OneofWrappers = []any{}
file_daemon_proto_msgTypes[46].OneofWrappers = []any{}
file_daemon_proto_msgTypes[52].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 3,
NumMessages: 55,
NumMessages: 57,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -67,6 +67,9 @@ service DaemonService {
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
// SetConfig updates daemon configuration without reconnecting
rpc SetConfig(SetConfigRequest) returns (SetConfigResponse) {}
}
@@ -356,7 +359,6 @@ message DebugBundleRequest {
string status = 2;
bool systemInfo = 3;
string uploadURL = 4;
uint32 logFileCount = 5;
}
message DebugBundleResponse {
@@ -497,3 +499,29 @@ message GetEventsRequest {}
message GetEventsResponse {
repeated SystemEvent events = 1;
}
message SetConfigRequest {
optional bool rosenpassEnabled = 1;
optional bool rosenpassPermissive = 2;
optional bool serverSSHAllowed = 3;
optional bool disableAutoConnect = 4;
optional bool networkMonitor = 5;
optional google.protobuf.Duration dnsRouteInterval = 6;
optional bool disable_client_routes = 7;
optional bool disable_server_routes = 8;
optional bool disable_dns = 9;
optional bool disable_firewall = 10;
optional bool block_lan_access = 11;
optional bool lazyConnectionEnabled = 12;
optional bool block_inbound = 13;
optional string interfaceName = 14;
optional int64 wireguardPort = 15;
optional string customDNSAddress = 16;
repeated string extraIFaceBlacklist = 17;
repeated string dns_labels = 18;
optional bool cleanDNSLabels = 19;
repeated string natExternalIPs = 20;
optional bool cleanNATExternalIPs = 21;
}
message SetConfigResponse {}

View File

@@ -55,6 +55,8 @@ type DaemonServiceClient interface {
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
// SetConfig updates daemon configuration without reconnecting
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
}
type daemonServiceClient struct {
@@ -268,6 +270,15 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques
return out, nil
}
func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) {
out := new(SetConfigResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetConfig", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -309,6 +320,8 @@ type DaemonServiceServer interface {
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
// SetConfig updates daemon configuration without reconnecting
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -376,6 +389,9 @@ func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, Daemo
func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented")
}
func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetConfig not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -752,6 +768,24 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetConfig_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetConfigRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetConfig(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetConfig",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetConfig(ctx, req.(*SetConfigRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -835,6 +869,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetEvents",
Handler: _DaemonService_GetEvents_Handler,
},
{
MethodName: "SetConfig",
Handler: _DaemonService_SetConfig_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -11,7 +11,7 @@ fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"

View File

@@ -42,7 +42,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
Anonymize: req.GetAnonymize(),
ClientStatus: req.GetStatus(),
IncludeSystemInfo: req.GetSystemInfo(),
LogFileCount: req.GetLogFileCount(),
},
)

View File

@@ -799,6 +799,133 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
}, nil
}
// SetConfig updates daemon configuration without reconnecting
func (s *Server) SetConfig(ctx context.Context, req *proto.SetConfigRequest) (*proto.SetConfigResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.config == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "daemon is not configured")
}
configChanged := false
if req.RosenpassEnabled != nil && s.config.RosenpassEnabled != *req.RosenpassEnabled {
s.config.RosenpassEnabled = *req.RosenpassEnabled
configChanged = true
}
if req.RosenpassPermissive != nil && s.config.RosenpassPermissive != *req.RosenpassPermissive {
s.config.RosenpassPermissive = *req.RosenpassPermissive
configChanged = true
}
if req.ServerSSHAllowed != nil && s.config.ServerSSHAllowed != nil && *s.config.ServerSSHAllowed != *req.ServerSSHAllowed {
*s.config.ServerSSHAllowed = *req.ServerSSHAllowed
configChanged = true
}
if req.DisableAutoConnect != nil && s.config.DisableAutoConnect != *req.DisableAutoConnect {
s.config.DisableAutoConnect = *req.DisableAutoConnect
configChanged = true
}
if req.NetworkMonitor != nil && s.config.NetworkMonitor != nil && *s.config.NetworkMonitor != *req.NetworkMonitor {
*s.config.NetworkMonitor = *req.NetworkMonitor
configChanged = true
}
if req.DnsRouteInterval != nil {
duration := req.DnsRouteInterval.AsDuration()
if s.config.DNSRouteInterval != duration {
s.config.DNSRouteInterval = duration
configChanged = true
}
}
if req.DisableClientRoutes != nil && s.config.DisableClientRoutes != *req.DisableClientRoutes {
s.config.DisableClientRoutes = *req.DisableClientRoutes
configChanged = true
}
if req.DisableServerRoutes != nil && s.config.DisableServerRoutes != *req.DisableServerRoutes {
s.config.DisableServerRoutes = *req.DisableServerRoutes
configChanged = true
}
if req.DisableDns != nil && s.config.DisableDNS != *req.DisableDns {
s.config.DisableDNS = *req.DisableDns
configChanged = true
}
if req.DisableFirewall != nil && s.config.DisableFirewall != *req.DisableFirewall {
s.config.DisableFirewall = *req.DisableFirewall
configChanged = true
}
if req.BlockLanAccess != nil && s.config.BlockLANAccess != *req.BlockLanAccess {
s.config.BlockLANAccess = *req.BlockLanAccess
configChanged = true
}
if req.LazyConnectionEnabled != nil && s.config.LazyConnectionEnabled != *req.LazyConnectionEnabled {
s.config.LazyConnectionEnabled = *req.LazyConnectionEnabled
configChanged = true
}
if req.BlockInbound != nil && s.config.BlockInbound != *req.BlockInbound {
s.config.BlockInbound = *req.BlockInbound
configChanged = true
}
if req.InterfaceName != nil && s.config.WgIface != *req.InterfaceName {
s.config.WgIface = *req.InterfaceName
configChanged = true
}
if req.WireguardPort != nil && s.config.WgPort != int(*req.WireguardPort) {
s.config.WgPort = int(*req.WireguardPort)
configChanged = true
}
if req.CustomDNSAddress != nil && s.config.CustomDNSAddress != *req.CustomDNSAddress {
s.config.CustomDNSAddress = *req.CustomDNSAddress
configChanged = true
}
if len(req.ExtraIFaceBlacklist) > 0 {
s.config.IFaceBlackList = req.ExtraIFaceBlacklist
configChanged = true
}
if len(req.DnsLabels) > 0 || (req.CleanDNSLabels != nil && *req.CleanDNSLabels) {
if req.CleanDNSLabels != nil && *req.CleanDNSLabels {
s.config.DNSLabels = domain.List{}
} else {
var err error
s.config.DNSLabels, err = domain.FromStringList(req.DnsLabels)
if err != nil {
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid DNS labels: %v", err)
}
}
configChanged = true
}
if len(req.NatExternalIPs) > 0 || (req.CleanNATExternalIPs != nil && *req.CleanNATExternalIPs) {
s.config.NATExternalIPs = req.NatExternalIPs
configChanged = true
}
if configChanged {
if err := internal.WriteOutConfig(s.latestConfigInput.ConfigPath, s.config); err != nil {
return nil, gstatus.Errorf(codes.Internal, "write config: %v", err)
}
log.Debug("Configuration updated successfully")
}
return &proto.SetConfigResponse{}, nil
}
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()

2
go.mod
View File

@@ -25,7 +25,7 @@ require (
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.36.6
google.golang.org/protobuf v1.36.5
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)

4
go.sum
View File

@@ -1164,8 +1164,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -87,12 +87,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.
EXPECT().

View File

@@ -16,7 +16,7 @@ type AccountsAPI struct {
// List list all accounts, only returns one account always
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil)
if err != nil {
return nil, err
}
@@ -34,7 +34,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
// Delete delete account
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
if err != nil {
return err
}

View File

@@ -117,7 +117,7 @@ func (c *Client) initialize() {
}
// NewRequest creates and executes new management API request
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader, query map[string]string) (*http.Response, error) {
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
if err != nil {
return nil, err
@@ -129,14 +129,6 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re
req.Header.Add("Content-Type", "application/json")
}
if len(query) != 0 {
q := req.URL.Query()
for k, v := range query {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err

View File

@@ -16,7 +16,7 @@ type DNSAPI struct {
// ListNameserverGroups list all nameserver groups
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou
// GetNameserverGroup get nameserver group info
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
// DeleteNameserverGroup delete nameserver group
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
if err != nil {
return err
}
@@ -94,7 +94,7 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st
// GetSettings get DNS settings
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil)
if err != nil {
return nil, err
}
@@ -112,7 +112,7 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}

View File

@@ -14,7 +14,7 @@ type EventsAPI struct {
// List list all events
// See more: https://docs.netbird.io/api/resources/events#list-all-events
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil)
if err != nil {
return nil, err
}

View File

@@ -14,7 +14,7 @@ type GeoLocationAPI struct {
// ListCountries list all country codes
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil)
if err != nil {
return nil, err
}
@@ -28,7 +28,7 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro
// ListCountryCities Get a list of all English city names for a given country code
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
if err != nil {
return nil, err
}

View File

@@ -16,7 +16,7 @@ type GroupsAPI struct {
// List list all groups
// See more: https://docs.netbird.io/api/resources/groups#list-all-groups
func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
// Get get group info
// See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group
func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
// Delete delete group
// See more: https://docs.netbird.io/api/resources/groups#delete-a-group
func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ type NetworksAPI struct {
// List list all networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-networks
func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
// Get get network info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network
func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
// Delete delete network
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network
func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil)
if err != nil {
return err
}
@@ -108,7 +108,7 @@ func (a *NetworksAPI) Resources(networkID string) *NetworkResourcesAPI {
// List list all resources in networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-resources
func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil)
if err != nil {
return nil, err
}
@@ -122,7 +122,7 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource,
// Get get network resource info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network-resource
func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) (*api.NetworkResource, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
if err != nil {
return nil, err
}
@@ -140,7 +140,7 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -158,7 +158,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -172,7 +172,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
// Delete delete network resource
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network-resource
func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
if err != nil {
return err
}
@@ -200,7 +200,7 @@ func (a *NetworksAPI) Routers(networkID string) *NetworkRoutersAPI {
// List list all routers in networks
// See more: https://docs.netbird.io/api/routers/networks#list-all-network-routers
func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil)
if err != nil {
return nil, err
}
@@ -214,7 +214,7 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro
// Get get network router info
// See more: https://docs.netbird.io/api/routers/networks#retrieve-a-network-router
func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
if err != nil {
return nil, err
}
@@ -232,7 +232,7 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -250,7 +250,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -264,7 +264,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
// Delete delete network router
// See more: https://docs.netbird.io/api/routers/networks#delete-a-network-router
func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
if err != nil {
return err
}

View File

@@ -13,30 +13,10 @@ type PeersAPI struct {
c *Client
}
// PeersListOption options for Peers List API
type PeersListOption func() (string, string)
func PeerNameFilter(name string) PeersListOption {
return func() (string, string) {
return "name", name
}
}
func PeerIPFilter(ip string) PeersListOption {
return func() (string, string) {
return "ip", ip
}
}
// List list all peers
// See more: https://docs.netbird.io/api/resources/peers#list-all-peers
func (a *PeersAPI) List(ctx context.Context, opts ...PeersListOption) ([]api.Peer, error) {
query := make(map[string]string)
for _, o := range opts {
k, v := o()
query[k] = v
}
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil, query)
func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil)
if err != nil {
return nil, err
}
@@ -50,7 +30,7 @@ func (a *PeersAPI) List(ctx context.Context, opts ...PeersListOption) ([]api.Pee
// Get retrieve a peer
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer
func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil)
if err != nil {
return nil, err
}
@@ -68,7 +48,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -82,7 +62,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
// Delete delete a peer
// See more: https://docs.netbird.io/api/resources/peers#delete-a-peer
func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil)
if err != nil {
return err
}
@@ -96,7 +76,7 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
// ListAccessiblePeers list all peers that the specified peer can connect to within the network
// See more: https://docs.netbird.io/api/resources/peers#list-accessible-peers
func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]api.Peer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil)
if err != nil {
return nil, err
}

View File

@@ -184,10 +184,6 @@ func TestPeers_Integration(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, peers)
filteredPeers, err := c.Peers.List(context.Background(), rest.PeerIPFilter("192.168.10.0"))
require.NoError(t, err)
require.Empty(t, filteredPeers)
peer, err := c.Peers.Get(context.Background(), peers[0].Id)
require.NoError(t, err)
assert.Equal(t, peers[0].Id, peer.Id)

View File

@@ -18,7 +18,7 @@ type PoliciesAPI struct {
func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
path := "/api/policies"
resp, err := a.c.NewRequest(ctx, "GET", path, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
@@ -32,7 +32,7 @@ func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
// Get get policy info
// See more: https://docs.netbird.io/api/resources/policies#retrieve-a-policy
func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil)
if err != nil {
return nil, err
}
@@ -50,7 +50,7 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -70,7 +70,7 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes), nil)
resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -84,7 +84,7 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P
// Delete delete policy
// See more: https://docs.netbird.io/api/resources/policies#delete-a-policy
func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil)
if err != nil {
return err
}

Some files were not shown because too many files have changed in this diff Show More