mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 15:14:03 -04:00
Compare commits
1 Commits
nb-interfa
...
set-comman
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
355bab9bb4 |
4
.github/workflows/git-town.yml
vendored
4
.github/workflows/git-town.yml
vendored
@@ -16,6 +16,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1.2.1
|
||||
- uses: git-town/action@v1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
skip-single-stacks: true
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
|
||||
16
.github/workflows/release.yml
vendored
16
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.21"
|
||||
SIGN_PIPE_VER: "v0.0.20"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -231,17 +231,3 @@ jobs:
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||
|
||||
post_on_forum:
|
||||
runs-on: ubuntu-latest
|
||||
continue-on-error: true
|
||||
needs: [trigger_signer]
|
||||
steps:
|
||||
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
|
||||
48
README.md
48
README.md
@@ -14,9 +14,6 @@
|
||||
<br>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
@@ -32,13 +29,13 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||
New: NetBird Kubernetes Operator
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -50,9 +47,10 @@
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
### Open-Source Network Security in a Single Platform
|
||||
|
||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
||||
|
||||

|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
@@ -136,3 +134,37 @@ We use open-source technologies like [WireGuard®](https://www.wireguard.com/),
|
||||
### Legal
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
## Configuration Management
|
||||
|
||||
Netbird now supports direct configuration management via CLI commands:
|
||||
|
||||
- You can use `netbird set` as a regular user if the daemon is running; it will securely update the config via the daemon.
|
||||
- If the daemon is not running, you need write access to the config file (typically requires root).
|
||||
|
||||
### Set a configuration value
|
||||
|
||||
```
|
||||
netbird set <setting> <value>
|
||||
# or using environment variables
|
||||
NB_INTERFACE_NAME=utun5 netbird set interface-name
|
||||
```
|
||||
|
||||
### Get a configuration value
|
||||
|
||||
```
|
||||
netbird get <setting>
|
||||
# or using environment variables
|
||||
NB_INTERFACE_NAME=utun5 netbird get interface-name
|
||||
```
|
||||
|
||||
### Show all configuration values
|
||||
|
||||
```
|
||||
netbird show
|
||||
```
|
||||
|
||||
- All settings support environment variable overrides: `NB_<SETTING>` or `WT_<SETTING>` (e.g. `NB_ENABLE_ROSENPASS=true`).
|
||||
- Supported settings: management-url, admin-url, interface-name, external-ip-map, extra-iface-blacklist, dns-resolver-address, extra-dns-labels, preshared-key, enable-rosenpass, rosenpass-permissive, allow-server-ssh, network-monitor, disable-auto-connect, disable-client-routes, disable-server-routes, disable-dns, disable-firewall, block-lan-access, block-inbound, enable-lazy-connection, wireguard-port, dns-router-interval.
|
||||
|
||||
See `netbird set --help`, `netbird get --help`, and `netbird show --help` for more details.
|
||||
|
||||
|
||||
@@ -64,9 +64,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
execWorkaround(androidSDKVersion)
|
||||
|
||||
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
@@ -205,10 +203,8 @@ func (c *Client) Networks() *NetworkArray {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
netStr := r.Network.String()
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
if routes[0].IsDynamic() {
|
||||
continue
|
||||
}
|
||||
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
@@ -218,7 +214,7 @@ func (c *Client) Networks() *NetworkArray {
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Network: routes[0].Network.String(),
|
||||
Peer: peer.FQDN,
|
||||
Status: peer.ConnStatus.String(),
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
||||
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
|
||||
// In Android version 11 and earlier, pidfd-related system calls
|
||||
// are not allowed by the seccomp policy, which causes crashes due
|
||||
// to SIGSYS signals.
|
||||
|
||||
//go:linkname checkPidfdOnce os.checkPidfdOnce
|
||||
var checkPidfdOnce func() error
|
||||
|
||||
func execWorkaround(androidSDKVersion int) {
|
||||
if androidSDKVersion > 30 { // above Android 11
|
||||
return
|
||||
}
|
||||
|
||||
checkPidfdOnce = func() error {
|
||||
return fmt.Errorf("unsupported Android version")
|
||||
}
|
||||
}
|
||||
@@ -17,18 +17,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
|
||||
var (
|
||||
logFileCount uint32
|
||||
systemInfoFlag bool
|
||||
uploadBundleFlag bool
|
||||
uploadBundleURLFlag string
|
||||
)
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
@@ -96,13 +88,12 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
}
|
||||
if uploadBundleFlag {
|
||||
request.UploadURL = uploadBundleURLFlag
|
||||
if debugUploadBundle {
|
||||
request.UploadURL = debugUploadBundleURL
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
if err != nil {
|
||||
@@ -114,7 +105,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if uploadBundleFlag {
|
||||
if debugUploadBundle {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
|
||||
@@ -232,13 +223,12 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
}
|
||||
if uploadBundleFlag {
|
||||
request.UploadURL = uploadBundleURLFlag
|
||||
if debugUploadBundle {
|
||||
request.UploadURL = debugUploadBundleURL
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
if err != nil {
|
||||
@@ -265,7 +255,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if uploadBundleFlag {
|
||||
if debugUploadBundle {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
|
||||
@@ -307,7 +297,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
@@ -385,15 +375,3 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect
|
||||
}
|
||||
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
|
||||
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
|
||||
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
|
||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,7 +39,10 @@ const (
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -71,7 +76,10 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
debugUploadBundle bool
|
||||
debugUploadBundleURL string
|
||||
lazyConnEnabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
@@ -80,6 +88,30 @@ var (
|
||||
Long: "",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
getCmd = &cobra.Command{
|
||||
Use: "get <setting>",
|
||||
Short: "Get a configuration value from the config file",
|
||||
Long: `Get a configuration value from the Netbird config file. You can also use NB_<SETTING> or WT_<SETTING> environment variables to override the value (same as 'set').`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: getFunc,
|
||||
}
|
||||
|
||||
showCmd = &cobra.Command{
|
||||
Use: "show",
|
||||
Short: "Show all configuration values",
|
||||
Long: `Show all configuration values from the Netbird config file, with environment variable overrides if present.`,
|
||||
Args: cobra.NoArgs,
|
||||
RunE: showFunc,
|
||||
}
|
||||
|
||||
reloadCmd = &cobra.Command{
|
||||
Use: "reload",
|
||||
Short: "Reload the configuration in the daemon (daemon mode)",
|
||||
Long: `Reload the configuration from disk in the running daemon. Use after 'set' to apply changes without restarting the service.`,
|
||||
Args: cobra.NoArgs,
|
||||
RunE: reloadFunc,
|
||||
}
|
||||
)
|
||||
|
||||
// Execute executes the root command.
|
||||
@@ -145,6 +177,9 @@ func init() {
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(getCmd)
|
||||
rootCmd.AddCommand(showCmd)
|
||||
rootCmd.AddCommand(reloadCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
@@ -177,8 +212,11 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -398,3 +436,167 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func getFunc(cmd *cobra.Command, args []string) error {
|
||||
setting := args[0]
|
||||
upper := strings.ToUpper(strings.ReplaceAll(setting, "-", "_"))
|
||||
if v, ok := os.LookupEnv("NB_" + upper); ok {
|
||||
cmd.Println(v)
|
||||
return nil
|
||||
} else if v, ok := os.LookupEnv("WT_" + upper); ok {
|
||||
cmd.Println(v)
|
||||
return nil
|
||||
}
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config: %v", err)
|
||||
}
|
||||
switch setting {
|
||||
case "management-url":
|
||||
cmd.Println(config.ManagementURL.String())
|
||||
case "admin-url":
|
||||
cmd.Println(config.AdminURL.String())
|
||||
case "interface-name":
|
||||
cmd.Println(config.WgIface)
|
||||
case "external-ip-map":
|
||||
cmd.Println(strings.Join(config.NATExternalIPs, ","))
|
||||
case "extra-iface-blacklist":
|
||||
cmd.Println(strings.Join(config.IFaceBlackList, ","))
|
||||
case "dns-resolver-address":
|
||||
cmd.Println(config.CustomDNSAddress)
|
||||
case "extra-dns-labels":
|
||||
cmd.Println(config.DNSLabels.SafeString())
|
||||
case "preshared-key":
|
||||
cmd.Println(config.PreSharedKey)
|
||||
case "enable-rosenpass":
|
||||
cmd.Println(config.RosenpassEnabled)
|
||||
case "rosenpass-permissive":
|
||||
cmd.Println(config.RosenpassPermissive)
|
||||
case "allow-server-ssh":
|
||||
if config.ServerSSHAllowed != nil {
|
||||
cmd.Println(*config.ServerSSHAllowed)
|
||||
} else {
|
||||
cmd.Println(false)
|
||||
}
|
||||
case "network-monitor":
|
||||
if config.NetworkMonitor != nil {
|
||||
cmd.Println(*config.NetworkMonitor)
|
||||
} else {
|
||||
cmd.Println(false)
|
||||
}
|
||||
case "disable-auto-connect":
|
||||
cmd.Println(config.DisableAutoConnect)
|
||||
case "disable-client-routes":
|
||||
cmd.Println(config.DisableClientRoutes)
|
||||
case "disable-server-routes":
|
||||
cmd.Println(config.DisableServerRoutes)
|
||||
case "disable-dns":
|
||||
cmd.Println(config.DisableDNS)
|
||||
case "disable-firewall":
|
||||
cmd.Println(config.DisableFirewall)
|
||||
case "block-lan-access":
|
||||
cmd.Println(config.BlockLANAccess)
|
||||
case "block-inbound":
|
||||
cmd.Println(config.BlockInbound)
|
||||
case "enable-lazy-connection":
|
||||
cmd.Println(config.LazyConnectionEnabled)
|
||||
case "wireguard-port":
|
||||
cmd.Println(config.WgPort)
|
||||
case "dns-router-interval":
|
||||
cmd.Println(config.DNSRouteInterval)
|
||||
default:
|
||||
return fmt.Errorf("unknown setting: %s", setting)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func showFunc(cmd *cobra.Command, args []string) error {
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config: %v", err)
|
||||
}
|
||||
settings := []string{
|
||||
"management-url", "admin-url", "interface-name", "external-ip-map", "extra-iface-blacklist", "dns-resolver-address", "extra-dns-labels", "preshared-key", "enable-rosenpass", "rosenpass-permissive", "allow-server-ssh", "network-monitor", "disable-auto-connect", "disable-client-routes", "disable-server-routes", "disable-dns", "disable-firewall", "block-lan-access", "block-inbound", "enable-lazy-connection", "wireguard-port", "dns-router-interval",
|
||||
}
|
||||
for _, setting := range settings {
|
||||
upper := strings.ToUpper(strings.ReplaceAll(setting, "-", "_"))
|
||||
var val string
|
||||
if v, ok := os.LookupEnv("NB_" + upper); ok {
|
||||
val = v + " (from NB_ env)"
|
||||
} else if v, ok := os.LookupEnv("WT_" + upper); ok {
|
||||
val = v + " (from WT_ env)"
|
||||
} else {
|
||||
switch setting {
|
||||
case "management-url":
|
||||
val = config.ManagementURL.String()
|
||||
case "admin-url":
|
||||
val = config.AdminURL.String()
|
||||
case "interface-name":
|
||||
val = config.WgIface
|
||||
case "external-ip-map":
|
||||
val = strings.Join(config.NATExternalIPs, ",")
|
||||
case "extra-iface-blacklist":
|
||||
val = strings.Join(config.IFaceBlackList, ",")
|
||||
case "dns-resolver-address":
|
||||
val = config.CustomDNSAddress
|
||||
case "extra-dns-labels":
|
||||
val = config.DNSLabels.SafeString()
|
||||
case "preshared-key":
|
||||
val = config.PreSharedKey
|
||||
case "enable-rosenpass":
|
||||
val = fmt.Sprintf("%v", config.RosenpassEnabled)
|
||||
case "rosenpass-permissive":
|
||||
val = fmt.Sprintf("%v", config.RosenpassPermissive)
|
||||
case "allow-server-ssh":
|
||||
if config.ServerSSHAllowed != nil {
|
||||
val = fmt.Sprintf("%v", *config.ServerSSHAllowed)
|
||||
} else {
|
||||
val = "false"
|
||||
}
|
||||
case "network-monitor":
|
||||
if config.NetworkMonitor != nil {
|
||||
val = fmt.Sprintf("%v", *config.NetworkMonitor)
|
||||
} else {
|
||||
val = "false"
|
||||
}
|
||||
case "disable-auto-connect":
|
||||
val = fmt.Sprintf("%v", config.DisableAutoConnect)
|
||||
case "disable-client-routes":
|
||||
val = fmt.Sprintf("%v", config.DisableClientRoutes)
|
||||
case "disable-server-routes":
|
||||
val = fmt.Sprintf("%v", config.DisableServerRoutes)
|
||||
case "disable-dns":
|
||||
val = fmt.Sprintf("%v", config.DisableDNS)
|
||||
case "disable-firewall":
|
||||
val = fmt.Sprintf("%v", config.DisableFirewall)
|
||||
case "block-lan-access":
|
||||
val = fmt.Sprintf("%v", config.BlockLANAccess)
|
||||
case "block-inbound":
|
||||
val = fmt.Sprintf("%v", config.BlockInbound)
|
||||
case "enable-lazy-connection":
|
||||
val = fmt.Sprintf("%v", config.LazyConnectionEnabled)
|
||||
case "wireguard-port":
|
||||
val = fmt.Sprintf("%d", config.WgPort)
|
||||
case "dns-router-interval":
|
||||
val = config.DNSRouteInterval.String()
|
||||
}
|
||||
}
|
||||
cmd.Printf("%-22s: %s\n", setting, val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadFunc(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
_, err = client.ReloadConfig(cmd.Context(), &proto.ReloadConfigRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reload config in daemon: %v", err)
|
||||
}
|
||||
cmd.Println("Configuration reloaded in daemon.")
|
||||
return nil
|
||||
}
|
||||
|
||||
475
client/cmd/set.go
Normal file
475
client/cmd/set.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
osuser "os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var setCmd = &cobra.Command{
|
||||
Use: "set <setting> <value>",
|
||||
Short: "Set a configuration value without running up",
|
||||
Long: `Set a configuration value in the Netbird config file without running 'up'.
|
||||
|
||||
You can also set values via environment variables NB_<SETTING> or WT_<SETTING> (e.g. NB_INTERFACE_NAME=utun5 netbird set interface-name).
|
||||
|
||||
Supported settings:
|
||||
management-url (string) e.g. https://api.netbird.io:443
|
||||
admin-url (string) e.g. https://app.netbird.io:443
|
||||
interface-name (string) e.g. utun5
|
||||
external-ip-map (list) comma-separated, e.g. 12.34.56.78,12.34.56.79/eth0
|
||||
extra-iface-blacklist (list) comma-separated, e.g. eth1,eth2
|
||||
dns-resolver-address (string) e.g. 127.0.0.1:5053
|
||||
extra-dns-labels (list) comma-separated, e.g. vpc1,mgmt1
|
||||
preshared-key (string)
|
||||
enable-rosenpass (bool) true/false
|
||||
rosenpass-permissive (bool) true/false
|
||||
allow-server-ssh (bool) true/false
|
||||
network-monitor (bool) true/false
|
||||
disable-auto-connect (bool) true/false
|
||||
disable-client-routes (bool) true/false
|
||||
disable-server-routes (bool) true/false
|
||||
disable-dns (bool) true/false
|
||||
disable-firewall (bool) true/false
|
||||
block-lan-access (bool) true/false
|
||||
block-inbound (bool) true/false
|
||||
enable-lazy-connection (bool) true/false
|
||||
wireguard-port (int) e.g. 51820
|
||||
dns-router-interval (duration) e.g. 1m, 30s
|
||||
|
||||
Examples:
|
||||
NB_INTERFACE_NAME=utun5 netbird set interface-name
|
||||
netbird set wireguard-port 51820
|
||||
netbird set external-ip-map 12.34.56.78,12.34.56.79/eth0
|
||||
netbird set enable-rosenpass true
|
||||
netbird set dns-router-interval 2m
|
||||
netbird set extra-dns-labels vpc1,mgmt1
|
||||
netbird set disable-firewall true
|
||||
`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: setFunc,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(setCmd)
|
||||
}
|
||||
|
||||
func setFunc(cmd *cobra.Command, args []string) error {
|
||||
setting := args[0]
|
||||
var value string
|
||||
|
||||
// Check environment variables first
|
||||
upper := strings.ToUpper(strings.ReplaceAll(setting, "-", "_"))
|
||||
if v, ok := os.LookupEnv("NB_" + upper); ok {
|
||||
value = v
|
||||
} else if v, ok := os.LookupEnv("WT_" + upper); ok {
|
||||
value = v
|
||||
} else {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("missing value for setting %s", setting)
|
||||
}
|
||||
value = args[1]
|
||||
}
|
||||
|
||||
// If not root, try to use the daemon (only if cmd is not nil)
|
||||
if cmd != nil {
|
||||
if u, err := osuser.Current(); err == nil && u.Uid != "0" {
|
||||
conn, err := getClient(cmd)
|
||||
if err == nil {
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
_, err = client.SetConfigValue(cmd.Context(), &proto.SetConfigValueRequest{Setting: setting, Value: value})
|
||||
if err == nil {
|
||||
if cmd != nil {
|
||||
cmd.Println("Configuration updated via daemon.")
|
||||
} else {
|
||||
fmt.Println("Configuration updated via daemon.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return fmt.Errorf("daemon error: %v", s.Message())
|
||||
}
|
||||
return fmt.Errorf("failed to update config via daemon: %v", err)
|
||||
}
|
||||
// else: fall back to direct file write
|
||||
}
|
||||
}
|
||||
|
||||
switch setting {
|
||||
case "management-url":
|
||||
input := internal.ConfigInput{ConfigPath: configPath, ManagementURL: value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set management-url: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set management-url to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set management-url to: %s\n", value)
|
||||
}
|
||||
case "admin-url":
|
||||
input := internal.ConfigInput{ConfigPath: configPath, AdminURL: value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set admin-url: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set admin-url to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set admin-url to: %s\n", value)
|
||||
}
|
||||
case "interface-name":
|
||||
if err := parseInterfaceName(value); err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, InterfaceName: &value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set interface-name: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set interface-name to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set interface-name to: %s\n", value)
|
||||
}
|
||||
case "external-ip-map":
|
||||
var ips []string
|
||||
if value == "" {
|
||||
ips = []string{}
|
||||
} else {
|
||||
ips = strings.Split(value, ",")
|
||||
}
|
||||
if err := validateNATExternalIPs(ips); err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, NATExternalIPs: ips}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set external-ip-map: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set external-ip-map to: %v\n", ips)
|
||||
} else {
|
||||
fmt.Printf("Set external-ip-map to: %v\n", ips)
|
||||
}
|
||||
case "extra-iface-blacklist":
|
||||
var ifaces []string
|
||||
if value == "" {
|
||||
ifaces = []string{}
|
||||
} else {
|
||||
ifaces = strings.Split(value, ",")
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, ExtraIFaceBlackList: ifaces}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set extra-iface-blacklist: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set extra-iface-blacklist to: %v\n", ifaces)
|
||||
} else {
|
||||
fmt.Printf("Set extra-iface-blacklist to: %v\n", ifaces)
|
||||
}
|
||||
case "dns-resolver-address":
|
||||
if value != "" && !isValidAddrPort(value) {
|
||||
return fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", value)
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, CustomDNSAddress: []byte(value)}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set dns-resolver-address: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set dns-resolver-address to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set dns-resolver-address to: %s\n", value)
|
||||
}
|
||||
case "extra-dns-labels":
|
||||
var labels []string
|
||||
if value == "" {
|
||||
labels = []string{}
|
||||
} else {
|
||||
labels = strings.Split(value, ",")
|
||||
}
|
||||
domains, err := domain.ValidateDomains(labels)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid DNS labels: %v", err)
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DNSLabels: domains}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set extra-dns-labels: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set extra-dns-labels to: %v\n", labels)
|
||||
} else {
|
||||
fmt.Printf("Set extra-dns-labels to: %v\n", labels)
|
||||
}
|
||||
case "preshared-key":
|
||||
input := internal.ConfigInput{ConfigPath: configPath, PreSharedKey: &value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set preshared-key: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set preshared-key to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set preshared-key to: %s\n", value)
|
||||
}
|
||||
case "hostname":
|
||||
// Hostname is not persisted in config, so just print a warning
|
||||
if cmd != nil {
|
||||
cmd.Printf("Warning: hostname is not persisted in config. Use --hostname with up command.\n")
|
||||
} else {
|
||||
fmt.Printf("Warning: hostname is not persisted in config. Use --hostname with up command.\n")
|
||||
}
|
||||
case "enable-rosenpass":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, RosenpassEnabled: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set enable-rosenpass: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set enable-rosenpass to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set enable-rosenpass to: %v\n", b)
|
||||
}
|
||||
case "rosenpass-permissive":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, RosenpassPermissive: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set rosenpass-permissive: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set rosenpass-permissive to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set rosenpass-permissive to: %v\n", b)
|
||||
}
|
||||
case "allow-server-ssh":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, ServerSSHAllowed: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set allow-server-ssh: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set allow-server-ssh to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set allow-server-ssh to: %v\n", b)
|
||||
}
|
||||
case "network-monitor":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, NetworkMonitor: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set network-monitor: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set network-monitor to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set network-monitor to: %v\n", b)
|
||||
}
|
||||
case "disable-auto-connect":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableAutoConnect: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-auto-connect: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-auto-connect to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-auto-connect to: %v\n", b)
|
||||
}
|
||||
case "disable-client-routes":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableClientRoutes: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-client-routes: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-client-routes to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-client-routes to: %v\n", b)
|
||||
}
|
||||
case "disable-server-routes":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableServerRoutes: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-server-routes: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-server-routes to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-server-routes to: %v\n", b)
|
||||
}
|
||||
case "disable-dns":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableDNS: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-dns: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-dns to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-dns to: %v\n", b)
|
||||
}
|
||||
case "disable-firewall":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableFirewall: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-firewall: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-firewall to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-firewall to: %v\n", b)
|
||||
}
|
||||
case "block-lan-access":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, BlockLANAccess: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set block-lan-access: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set block-lan-access to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set block-lan-access to: %v\n", b)
|
||||
}
|
||||
case "block-inbound":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, BlockInbound: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set block-inbound: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set block-inbound to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set block-inbound to: %v\n", b)
|
||||
}
|
||||
case "enable-lazy-connection":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, LazyConnectionEnabled: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set enable-lazy-connection: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set enable-lazy-connection to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set enable-lazy-connection to: %v\n", b)
|
||||
}
|
||||
case "wireguard-port":
|
||||
p, err := parseUint16(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pi := int(p)
|
||||
input := internal.ConfigInput{ConfigPath: configPath, WireguardPort: &pi}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set wireguard-port: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set wireguard-port to: %d\n", p)
|
||||
} else {
|
||||
fmt.Printf("Set wireguard-port to: %d\n", p)
|
||||
}
|
||||
case "dns-router-interval":
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid duration: %v", err)
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DNSRouteInterval: &d}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set dns-router-interval: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set dns-router-interval to: %s\n", d)
|
||||
} else {
|
||||
fmt.Printf("Set dns-router-interval to: %s\n", d)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown setting: %s", setting)
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
cmd.Println("Configuration updated successfully.")
|
||||
} else {
|
||||
fmt.Println("Configuration updated successfully.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseBool(val string) (bool, error) {
|
||||
v := strings.ToLower(val)
|
||||
if v == "true" || v == "1" {
|
||||
return true, nil
|
||||
}
|
||||
if v == "false" || v == "0" {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("invalid boolean value: %s", val)
|
||||
}
|
||||
|
||||
func parseUint16(val string) (uint16, error) {
|
||||
var p uint16
|
||||
_, err := fmt.Sscanf(val, "%d", &p)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid uint16 value: %s", val)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
162
client/cmd/set_test.go
Normal file
162
client/cmd/set_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetCommand_AllSettings(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "config.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
// Write empty JSON object to the config file to avoid JSON parse errors
|
||||
_, err = tempFile.WriteString("{}")
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
|
||||
configPath = tempFile.Name()
|
||||
|
||||
tests := []struct {
|
||||
setting string
|
||||
value string
|
||||
verify func(*testing.T, *internal.Config)
|
||||
wantErr bool
|
||||
}{
|
||||
{"management-url", "https://test.mgmt:443", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "https://test.mgmt:443", c.ManagementURL.String())
|
||||
}, false},
|
||||
{"admin-url", "https://test.admin:443", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "https://test.admin:443", c.AdminURL.String())
|
||||
}, false},
|
||||
{"interface-name", "utun99", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "utun99", c.WgIface)
|
||||
}, false},
|
||||
{"external-ip-map", "12.34.56.78,12.34.56.79", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, []string{"12.34.56.78", "12.34.56.79"}, c.NATExternalIPs)
|
||||
}, false},
|
||||
{"extra-iface-blacklist", "eth1,eth2", func(t *testing.T, c *internal.Config) {
|
||||
require.Contains(t, c.IFaceBlackList, "eth1")
|
||||
require.Contains(t, c.IFaceBlackList, "eth2")
|
||||
}, false},
|
||||
{"dns-resolver-address", "127.0.0.1:5053", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "127.0.0.1:5053", c.CustomDNSAddress)
|
||||
}, false},
|
||||
{"extra-dns-labels", "vpc1,mgmt1", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, strings.Contains(c.DNSLabels.SafeString(), "vpc1"))
|
||||
require.True(t, strings.Contains(c.DNSLabels.SafeString(), "mgmt1"))
|
||||
}, false},
|
||||
{"preshared-key", "testkey", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "testkey", c.PreSharedKey)
|
||||
}, false},
|
||||
{"enable-rosenpass", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.RosenpassEnabled)
|
||||
}, false},
|
||||
{"rosenpass-permissive", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.RosenpassPermissive)
|
||||
}, false},
|
||||
{"allow-server-ssh", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.NotNil(t, c.ServerSSHAllowed)
|
||||
require.True(t, *c.ServerSSHAllowed)
|
||||
}, false},
|
||||
{"network-monitor", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.NotNil(t, c.NetworkMonitor)
|
||||
require.False(t, *c.NetworkMonitor)
|
||||
}, false},
|
||||
{"disable-auto-connect", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.DisableAutoConnect)
|
||||
}, false},
|
||||
{"disable-client-routes", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.DisableClientRoutes)
|
||||
}, false},
|
||||
{"disable-server-routes", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.DisableServerRoutes)
|
||||
}, false},
|
||||
{"disable-dns", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.DisableDNS)
|
||||
}, false},
|
||||
{"disable-firewall", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.DisableFirewall)
|
||||
}, false},
|
||||
{"block-lan-access", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.BlockLANAccess)
|
||||
}, false},
|
||||
{"block-inbound", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.BlockInbound)
|
||||
}, false},
|
||||
{"enable-lazy-connection", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.LazyConnectionEnabled)
|
||||
}, false},
|
||||
{"wireguard-port", "51820", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, 51820, c.WgPort)
|
||||
}, false},
|
||||
{"dns-router-interval", "2m", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, 2*time.Minute, c.DNSRouteInterval)
|
||||
}, false},
|
||||
// Invalid cases
|
||||
{"enable-rosenpass", "notabool", nil, true},
|
||||
{"wireguard-port", "notanint", nil, true},
|
||||
{"dns-router-interval", "notaduration", nil, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.setting+"="+tt.value, func(t *testing.T) {
|
||||
args := []string{tt.setting, tt.value}
|
||||
err := setFunc(nil, args)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommand_EnvVars(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "config.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
// Write empty JSON object to the config file to avoid JSON parse errors
|
||||
_, err = tempFile.WriteString("{}")
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
|
||||
configPath = tempFile.Name()
|
||||
|
||||
os.Setenv("NB_INTERFACE_NAME", "utun77")
|
||||
defer os.Unsetenv("NB_INTERFACE_NAME")
|
||||
args := []string{"interface-name", "utun99"}
|
||||
err = setFunc(nil, args)
|
||||
require.NoError(t, err)
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "utun77", config.WgIface)
|
||||
|
||||
os.Unsetenv("NB_INTERFACE_NAME")
|
||||
os.Setenv("WT_INTERFACE_NAME", "utun88")
|
||||
defer os.Unsetenv("WT_INTERFACE_NAME")
|
||||
err = setFunc(nil, args)
|
||||
require.NoError(t, err)
|
||||
config, err = internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "utun88", config.WgIface)
|
||||
|
||||
os.Unsetenv("WT_INTERFACE_NAME")
|
||||
// No env var, should use CLI value
|
||||
err = setFunc(nil, args)
|
||||
require.NoError(t, err)
|
||||
config, err = internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "utun99", config.WgIface)
|
||||
}
|
||||
@@ -26,7 +26,6 @@ var (
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
connectionTypeFilter string
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -46,7 +45,6 @@ func init() {
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -91,7 +89,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
@@ -158,15 +156,6 @@ func parseFilters() error {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
|
||||
switch strings.ToLower(connectionTypeFilter) {
|
||||
case "", "p2p", "relayed":
|
||||
if strings.ToLower(connectionTypeFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,408 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var sum1, sum2 uint32
|
||||
|
||||
// Parallel processing - unroll and compute two sums simultaneously
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
||||
// Skip checksum field at [10:12]
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
|
||||
|
||||
sum := sum1 + sum2
|
||||
|
||||
// Handle remaining bytes for headers > 20 bytes
|
||||
for i := 20; i < len(header)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
|
||||
if len(header)%2 == 1 {
|
||||
sum += uint32(header[len(header)-1]) << 8
|
||||
}
|
||||
|
||||
// Optimized carry fold - single iteration handles most cases
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
|
||||
// Process 16 bytes at once with 4 parallel accumulators
|
||||
for i <= len(data)-16 {
|
||||
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
|
||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
|
||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
|
||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
|
||||
i += 16
|
||||
}
|
||||
|
||||
sum := sum1 + sum2 + sum3 + sum4
|
||||
|
||||
// Handle remaining bytes
|
||||
for i < len(data)-1 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
i += 2
|
||||
}
|
||||
|
||||
if len(data)%2 == 1 {
|
||||
sum += uint32(data[len(data)-1]) << 8
|
||||
}
|
||||
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
reverse: make(map[netip.Addr]netip.Addr),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
}
|
||||
|
||||
func (b *biDNATMap) delete(original netip.Addr) {
|
||||
if translated, exists := b.forward[original]; exists {
|
||||
delete(b.forward, original)
|
||||
delete(b.reverse, translated)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
}
|
||||
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||
}
|
||||
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
}
|
||||
|
||||
m.dnatMappings[originalAddr] = translatedAddr
|
||||
m.dnatBiMap.set(originalAddr, translatedAddr)
|
||||
|
||||
if len(m.dnatMappings) == 1 {
|
||||
m.dnatEnabled.Store(true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||
}
|
||||
|
||||
delete(m.dnatMappings, originalAddr)
|
||||
m.dnatBiMap.delete(originalAddr)
|
||||
if len(m.dnatMappings) == 0 {
|
||||
m.dnatEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
translated, exists := m.dnatBiMap.getTranslated(addr)
|
||||
m.dnatMutex.RUnlock()
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
|
||||
m.dnatMutex.RUnlock()
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldDst [4]byte
|
||||
copy(oldDst[:], packetData[16:20])
|
||||
newDst := newIP.As4()
|
||||
|
||||
copy(packetData[16:20], newDst[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldSrc [4]byte
|
||||
copy(oldSrc[:], packetData[12:16])
|
||||
newSrc := newIP.As4()
|
||||
|
||||
copy(packetData[12:16], newSrc[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := udpStart + 6
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
if oldChecksum == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
icmpData := packetData[icmpStart:]
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||
checksum := icmpChecksum(icmpData)
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||
}
|
||||
if len(oldBytes)%2 == 1 {
|
||||
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
||||
}
|
||||
|
||||
for i := 0; i < len(newBytes)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
||||
}
|
||||
if len(newBytes)%2 == 1 {
|
||||
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
||||
}
|
||||
}
|
||||
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
@@ -1,416 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// BenchmarkDNATTranslation measures the performance of DNAT operations
|
||||
func BenchmarkDNATTranslation(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
setupDNAT bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "tcp_with_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
description: "TCP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "tcp_without_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: false,
|
||||
description: "TCP packet without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "udp_with_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
description: "UDP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "udp_without_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: false,
|
||||
description: "UDP packet without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "icmp_with_dnat",
|
||||
proto: layers.IPProtocolICMPv4,
|
||||
setupDNAT: true,
|
||||
description: "ICMP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "icmp_without_dnat",
|
||||
proto: layers.IPProtocolICMPv4,
|
||||
setupDNAT: false,
|
||||
description: "ICMP packet without DNAT (baseline)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup DNAT mapping if needed
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
if sc.setupDNAT {
|
||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Create test packets
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||
|
||||
// Pre-establish connection for reverse DNAT test
|
||||
if sc.setupDNAT {
|
||||
manager.filterOutbound(outboundPacket, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
// Benchmark outbound DNAT translation
|
||||
b.Run("outbound", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time since translation modifies it
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Benchmark inbound reverse DNAT translation
|
||||
if sc.setupDNAT {
|
||||
b.Run("inbound_reverse", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time since translation modifies it
|
||||
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
|
||||
manager.filterInbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup multiple DNAT mappings
|
||||
numMappings := 100
|
||||
originalIPs := make([]netip.Addr, numMappings)
|
||||
translatedIPs := make([]netip.Addr, numMappings)
|
||||
|
||||
for i := 0; i < numMappings; i++ {
|
||||
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
// Pre-generate packets
|
||||
outboundPackets := make([][]byte, numMappings)
|
||||
inboundPackets := make([][]byte, numMappings)
|
||||
for i := 0; i < numMappings; i++ {
|
||||
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
|
||||
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||
// Establish connections
|
||||
manager.filterOutbound(outboundPackets[i], 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
b.Run("concurrent_outbound", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
idx := i % numMappings
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("concurrent_inbound", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
idx := i % numMappings
|
||||
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||
manager.filterInbound(packet, 0)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
|
||||
func BenchmarkDNATScaling(b *testing.B) {
|
||||
mappingCounts := []int{1, 10, 100, 1000}
|
||||
|
||||
for _, count := range mappingCounts {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup DNAT mappings
|
||||
for i := 0; i < count; i++ {
|
||||
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Test with the last mapping added (worst case for lookup)
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateDNATTestPacket creates a test packet for DNAT benchmarking
|
||||
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP.AsSlice(),
|
||||
DstIP: dstIP.AsSlice(),
|
||||
Protocol: proto,
|
||||
}
|
||||
|
||||
var transportLayer gopacket.SerializableLayer
|
||||
switch proto {
|
||||
case layers.IPProtocolTCP:
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
}
|
||||
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = tcp
|
||||
case layers.IPProtocolUDP:
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = udp
|
||||
case layers.IPProtocolICMPv4:
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||
}
|
||||
transportLayer = icmp
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||
require.NoError(tb, err)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
|
||||
func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
// Create test data for checksum calculations
|
||||
testData := make([]byte, 64) // Typical packet size for checksum testing
|
||||
for i := range testData {
|
||||
testData[i] = byte(i)
|
||||
}
|
||||
|
||||
b.Run("ipv4_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("icmp_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = icmpChecksum(testData)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("incremental_update", func(b *testing.B) {
|
||||
oldBytes := []byte{192, 168, 1, 100}
|
||||
newBytes := []byte{10, 0, 0, 100}
|
||||
oldChecksum := uint16(0x1234)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time to isolate allocation testing
|
||||
testPacket := make([]byte, len(packet))
|
||||
copy(testPacket, packet)
|
||||
|
||||
// Parse the packet fresh each time to get a clean decoder
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
manager.translateOutboundDNAT(testPacket, d)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
|
||||
func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||
// Create a test packet
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
|
||||
|
||||
b.Run("direct_byte_access", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Direct extraction from packet bytes
|
||||
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decoder_extraction", func(b *testing.B) {
|
||||
// Create decoder once for comparison
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Extract using decoder (traditional method)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
_ = dst
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
|
||||
func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||
// Create test IPv4 header (20 bytes)
|
||||
header := make([]byte, 20)
|
||||
for i := range header {
|
||||
header[i] = byte(i)
|
||||
}
|
||||
// Clear checksum field
|
||||
header[10] = 0
|
||||
header[11] = 0
|
||||
|
||||
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ipv4Checksum(header)
|
||||
}
|
||||
})
|
||||
|
||||
// Test incremental checksum updates
|
||||
oldIP := []byte{192, 168, 1, 100}
|
||||
newIP := []byte{10, 0, 0, 100}
|
||||
oldChecksum := uint16(0x1234)
|
||||
|
||||
b.Run("optimized_incremental_update", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
// Add DNAT mapping
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{"TCP", layers.IPProtocolTCP, 12345, 80},
|
||||
{"UDP", layers.IPProtocolUDP, 12345, 53},
|
||||
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test outbound DNAT translation
|
||||
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||
originalOutbound := make([]byte, len(outboundPacket))
|
||||
copy(originalOutbound, outboundPacket)
|
||||
|
||||
// Process outbound packet (should translate destination)
|
||||
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, translated, "Outbound packet should be translated")
|
||||
|
||||
// Verify destination IP was changed
|
||||
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
|
||||
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
|
||||
|
||||
// Test inbound reverse DNAT translation
|
||||
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
|
||||
originalInbound := make([]byte, len(inboundPacket))
|
||||
copy(originalInbound, inboundPacket)
|
||||
|
||||
// Process inbound packet (should reverse translate source)
|
||||
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, reversed, "Inbound packet should be reverse translated")
|
||||
|
||||
// Verify source IP was changed back to original
|
||||
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
|
||||
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
|
||||
|
||||
// Test that checksums are recalculated correctly
|
||||
if tc.protocol != layers.IPProtocolICMPv4 {
|
||||
// For TCP/UDP, verify the transport checksum was updated
|
||||
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
|
||||
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// parsePacket helper to create a decoder for testing
|
||||
func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
t.Helper()
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
// Test adding mapping
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify mapping exists
|
||||
result, exists := manager.getDNATTranslation(originalIP)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, translatedIP, result)
|
||||
|
||||
// Test reverse lookup
|
||||
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, originalIP, reverseResult)
|
||||
|
||||
// Test removing mapping
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify mapping no longer exists
|
||||
_, exists = manager.getDNATTranslation(originalIP)
|
||||
require.False(t, exists)
|
||||
|
||||
_, exists = manager.findReverseDNATMapping(translatedIP)
|
||||
require.False(t, exists)
|
||||
|
||||
// Test error cases
|
||||
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
|
||||
require.Error(t, err, "Should reject invalid original IP")
|
||||
|
||||
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
|
||||
require.Error(t, err, "Should reject invalid translated IP")
|
||||
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.filterOutbound(packetData, 0)
|
||||
dropped := m.processOutgoingHooks(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
|
||||
@@ -104,12 +104,6 @@ type Manager struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRule firewall.Rule
|
||||
|
||||
// Internal 1:1 DNAT
|
||||
dnatEnabled atomic.Bool
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -195,7 +189,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@@ -526,6 +519,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
@@ -572,14 +581,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FilterOutBound filters outgoing packets
|
||||
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
|
||||
return m.filterOutbound(packetData, size)
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||
return m.processOutgoingHooks(packetData, size)
|
||||
}
|
||||
|
||||
// FilterInbound filters incoming packets
|
||||
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
|
||||
return m.filterInbound(packetData, size)
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||
return m.dropFilter(packetData, size)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
@@ -587,7 +596,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -609,8 +618,8 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// for netflow we keep track even if the firewall is stateless
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -714,9 +723,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
|
||||
return false
|
||||
}
|
||||
|
||||
// filterInbound implements filtering logic for incoming packets.
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -738,15 +747,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
return false
|
||||
}
|
||||
@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(inbound, 0)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.filterOutbound(testOut, 0)
|
||||
manager.processOutgoingHooks(testOut, 0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(testIn, 0)
|
||||
manager.dropFilter(testIn, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(inbound, 0)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.filterOutbound(syn, 0)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.filterInbound(synack, 0)
|
||||
manager.dropFilter(synack, 0)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.filterOutbound(ack, 0)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(inbound, 0)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.filterOutbound(syn, 0)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.filterInbound(synack, 0)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.filterOutbound(ack, 0)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.filterOutbound(outPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.filterInbound(inPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.filterOutbound(p.syn, 0)
|
||||
manager.filterInbound(p.synAck, 0)
|
||||
manager.filterOutbound(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
// Data transfer
|
||||
manager.filterOutbound(p.request, 0)
|
||||
manager.filterInbound(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
// Connection teardown
|
||||
manager.filterOutbound(p.finClient, 0)
|
||||
manager.filterInbound(p.ackServer, 0)
|
||||
manager.filterInbound(p.finServer, 0)
|
||||
manager.filterOutbound(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.filterOutbound(syn, 0)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.filterInbound(synack, 0)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.filterOutbound(ack, 0)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.filterOutbound(outPackets[connIdx], 0)
|
||||
manager.filterInbound(inPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.filterOutbound(p.syn, 0)
|
||||
manager.filterInbound(p.synAck, 0)
|
||||
manager.filterOutbound(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
manager.filterOutbound(p.request, 0)
|
||||
manager.filterInbound(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
manager.filterOutbound(p.finClient, 0)
|
||||
manager.filterInbound(p.ackServer, 0)
|
||||
manager.filterInbound(p.finServer, 0)
|
||||
manager.filterOutbound(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||
})
|
||||
}
|
||||
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.filterInbound(buf.Bytes(), 0) {
|
||||
if m.dropFilter(buf.Bytes(), 0) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.filterOutbound(buf.Bytes(), 0)
|
||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.filterOutbound(buf.Bytes(), 0)
|
||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.filterInbound(testBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
const (
|
||||
saveFrequency = int64(5 * time.Second)
|
||||
)
|
||||
|
||||
type PeerRecord struct {
|
||||
Address netip.AddrPort
|
||||
LastActivity atomic.Int64 // UnixNano timestamp
|
||||
}
|
||||
|
||||
type ActivityRecorder struct {
|
||||
mu sync.RWMutex
|
||||
peers map[string]*PeerRecord // publicKey to PeerRecord map
|
||||
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
|
||||
}
|
||||
|
||||
func NewActivityRecorder() *ActivityRecorder {
|
||||
return &ActivityRecorder{
|
||||
peers: make(map[string]*PeerRecord),
|
||||
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastActivities returns a snapshot of peer last activity
|
||||
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
activities := make(map[string]monotime.Time, len(r.peers))
|
||||
for key, record := range r.peers {
|
||||
monoTime := record.LastActivity.Load()
|
||||
activities[key] = monotime.Time(monoTime)
|
||||
}
|
||||
return activities
|
||||
}
|
||||
|
||||
// UpsertAddress adds or updates the address for a publicKey
|
||||
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var record *PeerRecord
|
||||
record, exists := r.peers[publicKey]
|
||||
if exists {
|
||||
delete(r.addrToPeer, record.Address)
|
||||
record.Address = address
|
||||
} else {
|
||||
record = &PeerRecord{
|
||||
Address: address,
|
||||
}
|
||||
record.LastActivity.Store(int64(monotime.Now()))
|
||||
r.peers[publicKey] = record
|
||||
}
|
||||
|
||||
r.addrToPeer[address] = record
|
||||
}
|
||||
|
||||
func (r *ActivityRecorder) Remove(publicKey string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if record, exists := r.peers[publicKey]; exists {
|
||||
delete(r.addrToPeer, record.Address)
|
||||
delete(r.peers, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
// record updates LastActivity for the given address using atomic store
|
||||
func (r *ActivityRecorder) record(address netip.AddrPort) {
|
||||
r.mu.RLock()
|
||||
record, ok := r.addrToPeer[address]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
log.Warnf("could not find record for address %s", address)
|
||||
return
|
||||
}
|
||||
|
||||
now := int64(monotime.Now())
|
||||
last := record.LastActivity.Load()
|
||||
if now-last < saveFrequency {
|
||||
return
|
||||
}
|
||||
|
||||
_ = record.LastActivity.CompareAndSwap(last, now)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
||||
peer := "peer1"
|
||||
ar := NewActivityRecorder()
|
||||
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
|
||||
activities := ar.GetLastActivities()
|
||||
|
||||
p, ok := activities[peer]
|
||||
if !ok {
|
||||
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
||||
}
|
||||
|
||||
if monotime.Since(p) > 5*time.Second {
|
||||
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||
func init() {
|
||||
listener := nbnet.NewListener()
|
||||
if listener.ListenConfig.Control != nil {
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
|
||||
}
|
||||
}
|
||||
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ControlFns is not thread safe and should only be modified during init.
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -16,7 +15,6 @@ import (
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -53,24 +51,22 @@ type ICEBind struct {
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
activityRecorder *ActivityRecorder
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
@@ -104,10 +100,6 @@ func (s *ICEBind) Close() error {
|
||||
return s.StdNetBind.Close()
|
||||
}
|
||||
|
||||
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||
return s.activityRecorder
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
@@ -154,7 +146,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapUDPConn(conn),
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
@@ -207,11 +199,6 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
continue
|
||||
}
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
|
||||
if isTransportPkg(msg.Buffers, msg.N) {
|
||||
s.activityRecorder.record(addrPort)
|
||||
}
|
||||
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
@@ -270,13 +257,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
copy(buffs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||
|
||||
if isTransportPkg(buffs, sizes[0]) {
|
||||
if ep, ok := eps[0].(*Endpoint); ok {
|
||||
c.activityRecorder.record(ep.AddrPort)
|
||||
}
|
||||
}
|
||||
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
@@ -292,19 +272,3 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
func isTransportPkg(buffers [][]byte, n int) bool {
|
||||
// The first buffer should contain at least 4 bytes for type
|
||||
if len(buffers[0]) < 4 {
|
||||
return true
|
||||
}
|
||||
|
||||
// WireGuard packet type is a little-endian uint32 at start
|
||||
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
|
||||
|
||||
// Check if packetType matches known WireGuard message types
|
||||
if packetType == 4 && n > 32 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -296,20 +296,14 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
return
|
||||
}
|
||||
|
||||
var allAddresses []string
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
for _, c := range removedConns {
|
||||
addresses := c.getAddresses()
|
||||
allAddresses = append(allAddresses, addresses...)
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
for _, addr := range allAddresses {
|
||||
delete(m.addressMap, addr)
|
||||
}
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
for _, addr := range allAddresses {
|
||||
m.notifyAddressRemoval(addr)
|
||||
for _, addr := range addresses {
|
||||
delete(m.addressMap, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,13 +351,14 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
existing, ok := m.addressMap[addr]
|
||||
if !ok {
|
||||
existing = []*udpMuxedConn{}
|
||||
}
|
||||
existing = append(existing, conn)
|
||||
m.addressMap[addr] = existing
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||
}
|
||||
@@ -391,12 +386,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
m.addressMapMu.RLock()
|
||||
m.addressMapMu.Lock()
|
||||
var destinationConnList []*udpMuxedConn
|
||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||
destinationConnList = append(destinationConnList, storedConns...)
|
||||
}
|
||||
m.addressMapMu.RUnlock()
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
var isIPv6 bool
|
||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
wrapped, ok := m.params.UDPConn.(*UDPConn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
nbnetConn.RemoveAddress(addr)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package bind
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
|
||||
// wrap UDP connection, process server reflexive messages
|
||||
// before they are passed to the UDPMux connection handler (connWorker)
|
||||
m.params.UDPConn = &UDPConn{
|
||||
m.params.UDPConn = &udpConn{
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
@@ -70,6 +70,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
// embed UDPMux
|
||||
udpMuxParams := UDPMuxParams{
|
||||
Logger: params.Logger,
|
||||
UDPConn: m.params.UDPConn,
|
||||
@@ -113,8 +114,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type UDPConn struct {
|
||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type udpConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
@@ -124,12 +125,7 @@ type UDPConn struct {
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
// GetPacketConn returns the underlying PacketConn
|
||||
func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||
return u.PacketConn
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if u.filterFn == nil {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
@@ -141,21 +137,21 @@ func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
return u.handleUncachedAddress(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
host, err := getHostFromAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
var zeroKey wgtypes.Key
|
||||
@@ -278,7 +276,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,4 +3,4 @@
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "nb0"
|
||||
const WgInterfaceDefault = "wt0"
|
||||
|
||||
@@ -16,8 +16,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -38,18 +36,16 @@ const (
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
|
||||
type WGUSPConfigurer struct {
|
||||
device *device.Device
|
||||
deviceName string
|
||||
activityRecorder *bind.ActivityRecorder
|
||||
device *device.Device
|
||||
deviceName string
|
||||
|
||||
uapiListener net.Listener
|
||||
}
|
||||
|
||||
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
||||
wgCfg := &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
}
|
||||
wgCfg.startUAPI()
|
||||
return wgCfg
|
||||
@@ -91,19 +87,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
|
||||
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||
return ipcErr
|
||||
}
|
||||
|
||||
if endpoint != nil {
|
||||
addr, err := netip.ParseAddr(endpoint.IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||
}
|
||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||
}
|
||||
return nil
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
@@ -120,10 +104,7 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
|
||||
|
||||
c.activityRecorder.Remove(peerKey)
|
||||
return ipcErr
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
@@ -224,10 +205,6 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||
return parseStatus(c.deviceName, ipcStr)
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return c.activityRecorder.GetLastActivities()
|
||||
}
|
||||
|
||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||
func (t *WGUSPConfigurer) startUAPI() {
|
||||
var err error
|
||||
|
||||
@@ -79,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
|
||||
// PacketFilter interface for firewall abilities
|
||||
type PacketFilter interface {
|
||||
// FilterOutbound filter outgoing packets from host to external destinations
|
||||
FilterOutbound(packetData []byte, size int) bool
|
||||
// DropOutgoing filter outgoing packets from host to external destinations
|
||||
DropOutgoing(packetData []byte, size int) bool
|
||||
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
// DropIncoming filter incoming packets from external sources to host
|
||||
DropIncoming(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
dropped++
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
return 1, nil
|
||||
})
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
|
||||
@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
|
||||
@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
type WGConfigurer interface {
|
||||
@@ -20,5 +19,4 @@ type WGConfigurer interface {
|
||||
Close()
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -30,11 +29,6 @@ const (
|
||||
WgInterfaceDefault = configurer.WgInterfaceDefault
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrIfaceNotFound is returned when the WireGuard interface is not found
|
||||
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
|
||||
)
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
Free() error
|
||||
@@ -123,9 +117,6 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
|
||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
@@ -135,9 +126,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
|
||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
|
||||
return w.configurer.RemovePeer(peerKey)
|
||||
@@ -147,9 +135,6 @@ func (w *WGIface) RemovePeer(peerKey string) error {
|
||||
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||
return w.configurer.AddAllowedIP(peerKey, allowedIP)
|
||||
@@ -159,9 +144,6 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
|
||||
@@ -232,29 +214,10 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||
|
||||
// GetStats returns the last handshake time, rx and tx bytes
|
||||
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
if w.configurer == nil {
|
||||
return nil, ErrIfaceNotFound
|
||||
}
|
||||
return w.configurer.GetStats()
|
||||
}
|
||||
|
||||
func (w *WGIface) LastActivities() map[string]monotime.Time {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.configurer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.configurer.LastActivities()
|
||||
|
||||
}
|
||||
|
||||
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||
if w.configurer == nil {
|
||||
return nil, ErrIfaceNotFound
|
||||
}
|
||||
|
||||
return w.configurer.FullStats()
|
||||
}
|
||||
|
||||
|
||||
@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
|
||||
// DropIncoming mocks base method.
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
// DropIncoming indicates an expected call of DropIncoming.
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
|
||||
// DropOutgoing mocks base method.
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
|
||||
@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
// DropIncoming mocks base method.
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
// DropIncoming indicates an expected call of DropIncoming.
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
// DropOutgoing mocks base method.
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
|
||||
@@ -41,12 +41,9 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
}
|
||||
t.tundev = nsTunDev
|
||||
|
||||
var skipProxy bool
|
||||
if val := os.Getenv(EnvSkipProxy); val != "" {
|
||||
skipProxy, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
||||
}
|
||||
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
||||
}
|
||||
if skipProxy {
|
||||
return nsTunDev, tunNet, nil
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
type ProxyBind struct {
|
||||
@@ -29,17 +28,6 @@ type ProxyBind struct {
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
|
||||
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||
p := &ProxyBind{
|
||||
Bind: bind,
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn adds a new connection to the bind.
|
||||
@@ -66,10 +54,6 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||
p.closeListener.SetCloseListener(disconnected)
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
@@ -112,9 +96,6 @@ func (p *ProxyBind) close() error {
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.closeListener.SetCloseListener(nil)
|
||||
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
@@ -141,7 +122,6 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
p.closeListener.Notify()
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
@@ -28,15 +26,6 @@ type ProxyWrapper struct {
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
|
||||
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
||||
return &ProxyWrapper{
|
||||
WgeBPFProxy: WgeBPFProxy,
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
@@ -54,10 +43,6 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||
p.closeListener.SetCloseListener(disconnected)
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
@@ -92,8 +77,6 @@ func (e *ProxyWrapper) CloseConn() error {
|
||||
|
||||
e.cancel()
|
||||
|
||||
e.closeListener.SetCloseListener(nil)
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
@@ -134,7 +117,6 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
if ctx.Err() != nil {
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
p.closeListener.Notify()
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
}
|
||||
|
||||
@@ -36,8 +36,9 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||
}
|
||||
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
|
||||
return &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: w.ebpfProxy,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
|
||||
@@ -20,7 +20,9 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
||||
}
|
||||
|
||||
func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind)
|
||||
return &proxyBind.ProxyBind{
|
||||
Bind: w.bind,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package listener
|
||||
|
||||
type CloseListener struct {
|
||||
listener func()
|
||||
}
|
||||
|
||||
func NewCloseListener() *CloseListener {
|
||||
return &CloseListener{}
|
||||
}
|
||||
|
||||
func (c *CloseListener) SetCloseListener(listener func()) {
|
||||
c.listener = listener
|
||||
}
|
||||
|
||||
func (c *CloseListener) Notify() {
|
||||
if c.listener != nil {
|
||||
c.listener()
|
||||
}
|
||||
}
|
||||
@@ -12,5 +12,4 @@ type Proxy interface {
|
||||
Work() // Work start or resume the proxy
|
||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||
CloseConn() error
|
||||
SetDisconnectListener(disconnected func())
|
||||
}
|
||||
|
||||
@@ -98,7 +98,9 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
proxyWrapper := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
}
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
cerrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
// WGUDPProxy proxies
|
||||
@@ -29,8 +28,6 @@ type WGUDPProxy struct {
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
|
||||
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
||||
@@ -38,7 +35,6 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUDPProxy{
|
||||
localWGListenPort: wgPort,
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
return p
|
||||
}
|
||||
@@ -71,10 +67,6 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
|
||||
return endpointUdpAddr
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
|
||||
p.closeListener.SetCloseListener(disconnected)
|
||||
}
|
||||
|
||||
// Work starts the proxy or resumes it if it was paused
|
||||
func (p *WGUDPProxy) Work() {
|
||||
if p.remoteConn == nil {
|
||||
@@ -119,8 +111,6 @@ func (p *WGUDPProxy) close() error {
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.closeListener.SetCloseListener(nil)
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
@@ -151,7 +141,6 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
p.closeListener.Notify()
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
)
|
||||
|
||||
var defaultInterfaceBlacklist = []string{
|
||||
iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -25,11 +26,11 @@ import (
|
||||
//
|
||||
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||
type ConnMgr struct {
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
enabledLocally bool
|
||||
rosenpassEnabled bool
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
dispatcher *dispatcher.ConnectionDispatcher
|
||||
enabledLocally bool
|
||||
|
||||
lazyConnMgr *manager.Manager
|
||||
|
||||
@@ -38,12 +39,12 @@ type ConnMgr struct {
|
||||
lazyCtxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
|
||||
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
||||
e := &ConnMgr{
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
rosenpassEnabled: engineConfig.RosenpassEnabled,
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
dispatcher: dispatcher,
|
||||
}
|
||||
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||
e.enabledLocally = true
|
||||
@@ -63,11 +64,6 @@ func (e *ConnMgr) Start(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
return
|
||||
}
|
||||
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
}
|
||||
@@ -87,12 +83,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warnf("lazy connection manager is enabled by management feature flag")
|
||||
log.Infof("lazy connection manager is enabled by management feature flag")
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
return e.addPeersToLazyConnManager()
|
||||
@@ -142,7 +133,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
|
||||
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||
}
|
||||
|
||||
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
|
||||
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
|
||||
for _, peerID := range added {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
@@ -184,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
|
||||
PeerConnID: conn.ConnID(),
|
||||
Log: conn.Log,
|
||||
}
|
||||
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
@@ -210,7 +201,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer conn.Close(false)
|
||||
defer conn.Close()
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
@@ -220,27 +211,23 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||
conn.Log.Infof("removed peer from lazy conn manager")
|
||||
}
|
||||
|
||||
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
||||
conn, ok := e.peerStore.PeerConn(peerKey)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return conn, true
|
||||
}
|
||||
|
||||
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
|
||||
conn.Log.Infof("activated peer from inactive state")
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
|
||||
// If locally the lazy connection is disabled, we force the peer connection open.
|
||||
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
|
||||
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
|
||||
return conn, true
|
||||
}
|
||||
|
||||
func (e *ConnMgr) Close() {
|
||||
@@ -257,7 +244,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
|
||||
cfg := manager.Config{
|
||||
InactivityThreshold: inactivityThresholdEnv(),
|
||||
}
|
||||
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
|
||||
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
|
||||
|
||||
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
@@ -288,7 +275,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
|
||||
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||
}
|
||||
|
||||
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
|
||||
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
|
||||
}
|
||||
|
||||
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||
|
||||
@@ -167,7 +167,6 @@ type BundleGenerator struct {
|
||||
anonymize bool
|
||||
clientStatus string
|
||||
includeSystemInfo bool
|
||||
logFileCount uint32
|
||||
|
||||
archive *zip.Writer
|
||||
}
|
||||
@@ -176,7 +175,6 @@ type BundleConfig struct {
|
||||
Anonymize bool
|
||||
ClientStatus string
|
||||
IncludeSystemInfo bool
|
||||
LogFileCount uint32
|
||||
}
|
||||
|
||||
type GeneratorDependencies struct {
|
||||
@@ -187,12 +185,6 @@ type GeneratorDependencies struct {
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
// Default to 1 log file for backward compatibility when 0 is provided
|
||||
logFileCount := cfg.LogFileCount
|
||||
if logFileCount == 0 {
|
||||
logFileCount = 1
|
||||
}
|
||||
|
||||
return &BundleGenerator{
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
|
||||
@@ -204,7 +196,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
anonymize: cfg.Anonymize,
|
||||
clientStatus: cfg.ClientStatus,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
logFileCount: logFileCount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -570,8 +561,32 @@ func (g *BundleGenerator) addLogfile() error {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
// add rotated log files based on logFileCount
|
||||
g.addRotatedLogFiles(logDir)
|
||||
// add latest rotated log file
|
||||
pattern := filepath.Join(logDir, "client-*.log.gz")
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("failed to glob rotated logs: %v", err)
|
||||
} else if len(files) > 0 {
|
||||
// pick the file with the latest ModTime
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
fi, err := os.Stat(files[i])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
|
||||
return false
|
||||
}
|
||||
fj, err := os.Stat(files[j])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
|
||||
return false
|
||||
}
|
||||
return fi.ModTime().Before(fj.ModTime())
|
||||
})
|
||||
latest := files[len(files)-1]
|
||||
name := filepath.Base(latest)
|
||||
if err := g.addSingleLogFileGz(latest, name); err != nil {
|
||||
log.Warnf("failed to add rotated log %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
stdErrLogPath := filepath.Join(logDir, errorLogFile)
|
||||
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
|
||||
@@ -655,52 +670,6 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRotatedLogFiles adds rotated log files to the bundle based on logFileCount
|
||||
func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
|
||||
if g.logFileCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
pattern := filepath.Join(logDir, "client-*.log.gz")
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("failed to glob rotated logs: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// sort files by modification time (newest first)
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
fi, err := os.Stat(files[i])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
|
||||
return false
|
||||
}
|
||||
fj, err := os.Stat(files[j])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
|
||||
return false
|
||||
}
|
||||
return fi.ModTime().After(fj.ModTime())
|
||||
})
|
||||
|
||||
// include up to logFileCount rotated files
|
||||
maxFiles := int(g.logFileCount)
|
||||
if maxFiles > len(files) {
|
||||
maxFiles = len(files)
|
||||
}
|
||||
|
||||
for i := 0; i < maxFiles; i++ {
|
||||
name := filepath.Base(files[i])
|
||||
if err := g.addSingleLogFileGz(files[i], name); err != nil {
|
||||
log.Warnf("failed to add rotated log %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
|
||||
header := &zip.FileHeader{
|
||||
Name: filename,
|
||||
|
||||
@@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
@@ -61,6 +62,7 @@ import (
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -137,6 +139,9 @@ type Engine struct {
|
||||
|
||||
connMgr *ConnMgr
|
||||
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
|
||||
// rpManager is a Rosenpass manager
|
||||
rpManager *rosenpass.Manager
|
||||
|
||||
@@ -170,7 +175,8 @@ type Engine struct {
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
|
||||
statusRecorder *peer.Status
|
||||
statusRecorder *peer.Status
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
|
||||
firewall firewallManager.Manager
|
||||
routeManager routemanager.Manager
|
||||
@@ -377,13 +383,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.stateManager.Start()
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("read initial settings: %w", err)
|
||||
}
|
||||
|
||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
@@ -400,13 +400,16 @@ func (e *Engine) Start() error {
|
||||
InitialRoutes: initialRoutes,
|
||||
StateManager: e.stateManager,
|
||||
DNSServer: dnsServer,
|
||||
DNSFeatureFlag: dnsFeatureFlag,
|
||||
PeerStore: e.peerStore,
|
||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||
})
|
||||
if err := e.routeManager.Init(); err != nil {
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
} else {
|
||||
e.beforePeerHook = beforePeerHook
|
||||
e.afterPeerHook = afterPeerHook
|
||||
}
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
@@ -448,7 +451,9 @@ func (e *Engine) Start() error {
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
@@ -483,9 +488,9 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
|
||||
if e.config.BlockLANAccess {
|
||||
@@ -1004,6 +1009,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
@@ -1014,7 +1021,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||
}
|
||||
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
@@ -1249,10 +1255,14 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
}
|
||||
|
||||
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
||||
conn.Close(false)
|
||||
conn.Close()
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
|
||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1292,12 +1302,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
PeerConnDispatcher: e.peerConnDispatcher,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1320,16 +1331,11 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
}
|
||||
|
||||
msgType := msg.GetBody().GetType()
|
||||
if msgType != sProto.Body_GO_IDLE {
|
||||
e.connMgr.ActivatePeer(e.ctx, conn)
|
||||
}
|
||||
|
||||
switch msg.GetBody().Type {
|
||||
case sProto.Body_OFFER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
@@ -1386,8 +1392,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
|
||||
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||
case sProto.Body_MODE:
|
||||
case sProto.Body_GO_IDLE:
|
||||
e.connMgr.DeactivatePeer(conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1485,12 +1489,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
if runtime.GOOS != "android" {
|
||||
// nolint:nilnil
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
info := system.GetInfo(e.ctx)
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
@@ -1507,12 +1506,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
return nil, nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||
dnsFeatureFlag := toDNSFeatureFlag(netMap)
|
||||
return routes, &dnsCfg, dnsFeatureFlag, nil
|
||||
return routes, &dnsCfg, nil
|
||||
}
|
||||
|
||||
func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
@@ -1560,14 +1558,18 @@ func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
// due to tests where we are using a mocked version of the DNS server
|
||||
if e.dnsServer != nil {
|
||||
return e.dnsServer, nil
|
||||
return nil, e.dnsServer, nil
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
routes, dnsConfig, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
||||
e.ctx,
|
||||
e.wgInterface,
|
||||
@@ -1578,19 +1580,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
e.config.DisableDNS,
|
||||
)
|
||||
go e.mobileDep.DnsReadyListener.OnReady()
|
||||
return dnsServer, nil
|
||||
return routes, dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
return nil, dnsServer, nil
|
||||
|
||||
default:
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return dnsServer, nil
|
||||
return nil, dnsServer, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -52,7 +53,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
@@ -97,7 +97,6 @@ type MockWGIface struct {
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetNetFunc func() *netstack.Net
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
|
||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||
@@ -188,13 +187,6 @@ func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) LastActivities() map[string]monotime.Time {
|
||||
if m.LastActivitiesFunc != nil {
|
||||
return m.LastActivitiesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", "console")
|
||||
code := m.Run()
|
||||
@@ -400,7 +392,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
StatusRecorder: engine.statusRecorder,
|
||||
RelayManager: relayMgr,
|
||||
})
|
||||
err = engine.routeManager.Init()
|
||||
_, _, err = engine.routeManager.Init()
|
||||
require.NoError(t, err)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -412,7 +404,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
engine.ctx = ctx
|
||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
type testCase struct {
|
||||
@@ -801,7 +793,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
|
||||
engine.routeManager = mockRouteManager
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
@@ -999,7 +991,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
engine.dnsServer = mockDNSServer
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
@@ -1393,7 +1385,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("nb%d", i)
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
@@ -1481,10 +1473,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
@@ -1494,7 +1482,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
type wgIfaceBase interface {
|
||||
@@ -39,5 +38,4 @@ type wgIfaceBase interface {
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetNet() *netstack.Net
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||
type Listener struct {
|
||||
wgIface WgInterface
|
||||
wgIface lazyconn.WGIface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
conn *net.UDPConn
|
||||
endpoint *net.UDPAddr
|
||||
@@ -22,7 +22,7 @@ type Listener struct {
|
||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||
}
|
||||
|
||||
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
d := &Listener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
@@ -48,7 +48,7 @@ func (d *Listener) ReadPackets() {
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||
if err != nil {
|
||||
if d.isClosed.Load() {
|
||||
d.peerCfg.Log.Infof("exit from activity listener")
|
||||
d.peerCfg.Log.Debugf("exit from activity listener")
|
||||
} else {
|
||||
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
|
||||
}
|
||||
@@ -59,11 +59,9 @@ func (d *Listener) ReadPackets() {
|
||||
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||
continue
|
||||
}
|
||||
d.peerCfg.Log.Infof("activity detected")
|
||||
break
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
if err := d.removeEndpoint(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
@@ -73,7 +71,7 @@ func (d *Listener) ReadPackets() {
|
||||
}
|
||||
|
||||
func (d *Listener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
|
||||
d.isClosed.Store(true)
|
||||
|
||||
if err := d.conn.Close(); err != nil {
|
||||
@@ -83,6 +81,7 @@ func (d *Listener) Close() {
|
||||
}
|
||||
|
||||
func (d *Listener) removeEndpoint() error {
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,27 +1,18 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type WgInterface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
OnActivityChan chan peerid.ConnID
|
||||
|
||||
wgIface WgInterface
|
||||
wgIface lazyconn.WGIface
|
||||
|
||||
peers map[peerid.ConnID]*Listener
|
||||
done chan struct{}
|
||||
@@ -29,7 +20,7 @@ type Manager struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(wgIface WgInterface) *Manager {
|
||||
func NewManager(wgIface lazyconn.WGIface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
wgIface: wgIface,
|
||||
|
||||
75
client/internal/lazyconn/inactivity/inactivity.go
Normal file
75
client/internal/lazyconn/inactivity/inactivity.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
peer "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
|
||||
MinimumInactivityThreshold = 3 * time.Minute
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
id peer.ConnID
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
inactivityThreshold time.Duration
|
||||
}
|
||||
|
||||
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
|
||||
i := &Monitor{
|
||||
id: peerID,
|
||||
timer: time.NewTimer(0),
|
||||
inactivityThreshold: threshold,
|
||||
}
|
||||
i.timer.Stop()
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
defer i.timer.Stop()
|
||||
|
||||
ctx, i.cancel = context.WithCancel(ctx)
|
||||
defer func() {
|
||||
defer i.cancel()
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
select {
|
||||
case timeoutChan <- i.id:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Monitor) Stop() {
|
||||
if i.cancel == nil {
|
||||
return
|
||||
}
|
||||
i.cancel()
|
||||
}
|
||||
|
||||
func (i *Monitor) PauseTimer() {
|
||||
i.timer.Stop()
|
||||
}
|
||||
|
||||
func (i *Monitor) ResetTimer() {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
}
|
||||
|
||||
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
|
||||
i.Stop()
|
||||
go i.Start(ctx, timeoutChan)
|
||||
}
|
||||
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
}
|
||||
|
||||
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
return peerid.ConnID(m)
|
||||
}
|
||||
|
||||
func TestInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReuseInactivityMonitor(t *testing.T) {
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
for i := 2; i > 0; i-- {
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(testTimeoutCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
testTimeoutCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
im.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPauseInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
trashHold := time.Second * 3
|
||||
im := NewInactivityMonitor(p.ConnID(), trashHold)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(ctx, timeoutChan)
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second) // grant time to start the monitor
|
||||
im.PauseTimer()
|
||||
|
||||
// check to do not receive timeout
|
||||
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
|
||||
defer thresholdCancel()
|
||||
select {
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-thresholdCtx.Done():
|
||||
// test ok
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
}
|
||||
|
||||
// test reset timer
|
||||
im.ResetTimer()
|
||||
|
||||
select {
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
// expected timeout
|
||||
}
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
const (
|
||||
checkInterval = 1 * time.Minute
|
||||
|
||||
DefaultInactivityThreshold = 15 * time.Minute
|
||||
MinimumInactivityThreshold = 1 * time.Minute
|
||||
)
|
||||
|
||||
type WgInterface interface {
|
||||
LastActivities() map[string]monotime.Time
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
inactivePeersChan chan map[string]struct{}
|
||||
|
||||
iface WgInterface
|
||||
interestedPeers map[string]*lazyconn.PeerConfig
|
||||
inactivityThreshold time.Duration
|
||||
}
|
||||
|
||||
func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager {
|
||||
inactivityThreshold, err := validateInactivityThreshold(configuredThreshold)
|
||||
if err != nil {
|
||||
inactivityThreshold = DefaultInactivityThreshold
|
||||
log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold)
|
||||
}
|
||||
|
||||
log.Infof("inactivity threshold configured: %v", inactivityThreshold)
|
||||
return &Manager{
|
||||
inactivePeersChan: make(chan map[string]struct{}, 1),
|
||||
iface: iface,
|
||||
interestedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||
inactivityThreshold: inactivityThreshold,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) InactivePeersChan() chan map[string]struct{} {
|
||||
if m == nil {
|
||||
// return a nil channel that blocks forever
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.inactivePeersChan
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists {
|
||||
return
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("adding peer to inactivity manager")
|
||||
m.interestedPeers[peerCfg.PublicKey] = peerCfg
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(peer string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pi, ok := m.interestedPeers[peer]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pi.Log.Debugf("remove peer from inactivity manager")
|
||||
delete(m.interestedPeers, peer)
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := newTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C():
|
||||
idlePeers, err := m.checkStats()
|
||||
if err != nil {
|
||||
log.Errorf("error checking stats: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(idlePeers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m.notifyInactivePeers(ctx, idlePeers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) {
|
||||
select {
|
||||
case m.inactivePeersChan <- inactivePeers:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkStats() (map[string]struct{}, error) {
|
||||
lastActivities := m.iface.LastActivities()
|
||||
|
||||
idlePeers := make(map[string]struct{})
|
||||
|
||||
checkTime := time.Now()
|
||||
for peerID, peerCfg := range m.interestedPeers {
|
||||
lastActive, ok := lastActivities[peerID]
|
||||
if !ok {
|
||||
// when peer is in connecting state
|
||||
peerCfg.Log.Warnf("peer not found in wg stats")
|
||||
continue
|
||||
}
|
||||
|
||||
since := monotime.Since(lastActive)
|
||||
if since > m.inactivityThreshold {
|
||||
peerCfg.Log.Infof("peer is inactive since time: %s", checkTime.Add(-since).String())
|
||||
idlePeers[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return idlePeers, nil
|
||||
}
|
||||
|
||||
func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) {
|
||||
if configuredThreshold == nil {
|
||||
return DefaultInactivityThreshold, nil
|
||||
}
|
||||
if *configuredThreshold < MinimumInactivityThreshold {
|
||||
return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold)
|
||||
}
|
||||
return *configuredThreshold, nil
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
type mockWgInterface struct {
|
||||
lastActivities map[string]monotime.Time
|
||||
}
|
||||
|
||||
func (m *mockWgInterface) LastActivities() map[string]monotime.Time {
|
||||
return m.lastActivities
|
||||
}
|
||||
|
||||
func TestPeerTriggersInactivity(t *testing.T) {
|
||||
peerID := "peer1"
|
||||
|
||||
wgMock := &mockWgInterface{
|
||||
lastActivities: map[string]monotime.Time{
|
||||
peerID: monotime.Time(int64(monotime.Now()) - int64(20*time.Minute)),
|
||||
},
|
||||
}
|
||||
|
||||
fakeTick := make(chan time.Time, 1)
|
||||
newTicker = func(d time.Duration) Ticker {
|
||||
return &fakeTickerMock{CChan: fakeTick}
|
||||
}
|
||||
|
||||
peerLog := log.WithField("peer", peerID)
|
||||
peerCfg := &lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
Log: peerLog,
|
||||
}
|
||||
|
||||
manager := NewManager(wgMock, nil)
|
||||
manager.AddPeer(peerCfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start the manager in a goroutine
|
||||
go manager.Start(ctx)
|
||||
|
||||
// Send a tick to simulate time passage
|
||||
fakeTick <- time.Now()
|
||||
|
||||
// Check if peer appears on inactivePeersChan
|
||||
select {
|
||||
case inactivePeers := <-manager.inactivePeersChan:
|
||||
assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive")
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("expected inactivity event, but none received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerTriggersActivity(t *testing.T) {
|
||||
peerID := "peer1"
|
||||
|
||||
wgMock := &mockWgInterface{
|
||||
lastActivities: map[string]monotime.Time{
|
||||
peerID: monotime.Time(int64(monotime.Now()) - int64(5*time.Minute)),
|
||||
},
|
||||
}
|
||||
|
||||
fakeTick := make(chan time.Time, 1)
|
||||
newTicker = func(d time.Duration) Ticker {
|
||||
return &fakeTickerMock{CChan: fakeTick}
|
||||
}
|
||||
|
||||
peerLog := log.WithField("peer", peerID)
|
||||
peerCfg := &lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
Log: peerLog,
|
||||
}
|
||||
|
||||
manager := NewManager(wgMock, nil)
|
||||
manager.AddPeer(peerCfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start the manager in a goroutine
|
||||
go manager.Start(ctx)
|
||||
|
||||
// Send a tick to simulate time passage
|
||||
fakeTick <- time.Now()
|
||||
|
||||
// Check if peer appears on inactivePeersChan
|
||||
select {
|
||||
case <-manager.inactivePeersChan:
|
||||
t.Fatal("expected inactive peer to be marked inactive")
|
||||
case <-time.After(1 * time.Second):
|
||||
// No inactivity event should be received
|
||||
}
|
||||
}
|
||||
|
||||
// fakeTickerMock implements Ticker interface for testing
|
||||
type fakeTickerMock struct {
|
||||
CChan chan time.Time
|
||||
}
|
||||
|
||||
func (f *fakeTickerMock) C() <-chan time.Time {
|
||||
return f.CChan
|
||||
}
|
||||
|
||||
func (f *fakeTickerMock) Stop() {}
|
||||
@@ -1,24 +0,0 @@
|
||||
package inactivity
|
||||
|
||||
import "time"
|
||||
|
||||
var newTicker = func(d time.Duration) Ticker {
|
||||
return &realTicker{t: time.NewTicker(d)}
|
||||
}
|
||||
|
||||
type Ticker interface {
|
||||
C() <-chan time.Time
|
||||
Stop()
|
||||
}
|
||||
|
||||
type realTicker struct {
|
||||
t *time.Ticker
|
||||
}
|
||||
|
||||
func (r *realTicker) C() <-chan time.Time {
|
||||
return r.t.C
|
||||
}
|
||||
|
||||
func (r *realTicker) Stop() {
|
||||
r.t.Stop()
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -42,46 +43,60 @@ type Config struct {
|
||||
type Manager struct {
|
||||
engineCtx context.Context
|
||||
peerStore *peerstore.Store
|
||||
connStateDispatcher *dispatcher.ConnectionDispatcher
|
||||
inactivityThreshold time.Duration
|
||||
|
||||
connStateListener *dispatcher.ConnectionListener
|
||||
managedPeers map[string]*lazyconn.PeerConfig
|
||||
managedPeersByConnID map[peerid.ConnID]*managedPeer
|
||||
excludes map[string]lazyconn.PeerConfig
|
||||
managedPeersMu sync.Mutex
|
||||
|
||||
activityManager *activity.Manager
|
||||
inactivityManager *inactivity.Manager
|
||||
activityManager *activity.Manager
|
||||
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
|
||||
|
||||
// Route HA group management
|
||||
// If any peer in the same HA group is active, all peers in that group should prevent going idle
|
||||
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
|
||||
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
|
||||
routesMu sync.RWMutex
|
||||
|
||||
onInactive chan peerid.ConnID
|
||||
}
|
||||
|
||||
// NewManager creates a new lazy connection manager
|
||||
// engineCtx is the context for creating peer Connection
|
||||
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager {
|
||||
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
|
||||
log.Infof("setup lazy connection service")
|
||||
|
||||
m := &Manager{
|
||||
engineCtx: engineCtx,
|
||||
peerStore: peerStore,
|
||||
connStateDispatcher: connStateDispatcher,
|
||||
inactivityThreshold: inactivity.DefaultInactivityThreshold,
|
||||
managedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
|
||||
excludes: make(map[string]lazyconn.PeerConfig),
|
||||
activityManager: activity.NewManager(wgIface),
|
||||
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
|
||||
peerToHAGroups: make(map[string][]route.HAUniqueID),
|
||||
haGroupToPeers: make(map[route.HAUniqueID][]string),
|
||||
onInactive: make(chan peerid.ConnID),
|
||||
}
|
||||
|
||||
if wgIface.IsUserspaceBind() {
|
||||
m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold)
|
||||
} else {
|
||||
log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection")
|
||||
if config.InactivityThreshold != nil {
|
||||
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
|
||||
m.inactivityThreshold = *config.InactivityThreshold
|
||||
} else {
|
||||
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
m.connStateListener = &dispatcher.ConnectionListener{
|
||||
OnConnected: m.onPeerConnected,
|
||||
OnDisconnected: m.onPeerDisconnected,
|
||||
}
|
||||
|
||||
connStateDispatcher.AddListener(m.connStateListener)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -116,28 +131,24 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups))
|
||||
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
|
||||
len(m.haGroupToPeers), len(m.peerToHAGroups))
|
||||
}
|
||||
|
||||
// Start starts the manager and listens for peer activity and inactivity events
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
defer m.close()
|
||||
|
||||
if m.inactivityManager != nil {
|
||||
go m.inactivityManager.Start(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(peerConnID)
|
||||
case peerIDs := <-m.inactivityManager.InactivePeersChan():
|
||||
m.onPeerInactivityTimedOut(peerIDs)
|
||||
m.onPeerActivity(ctx, peerConnID)
|
||||
case peerConnID := <-m.onInactive:
|
||||
m.onPeerInactivityTimedOut(ctx, peerConnID)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ExcludePeer marks peers for a permanent connection
|
||||
@@ -145,7 +156,7 @@ func (m *Manager) Start(ctx context.Context) {
|
||||
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
|
||||
// this case, we suppose that the connection status is connected or connecting.
|
||||
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
|
||||
func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
|
||||
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -176,7 +187,7 @@ func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
|
||||
|
||||
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
|
||||
|
||||
if err := m.addActivePeer(&peerCfg); err != nil {
|
||||
if err := m.addActivePeer(ctx, peerCfg); err != nil {
|
||||
log.Errorf("failed to add peer to lazy connection manager: %s", err)
|
||||
continue
|
||||
}
|
||||
@@ -186,7 +197,7 @@ func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
|
||||
return added
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -206,6 +217,9 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: &peerCfg,
|
||||
@@ -215,7 +229,7 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
// Check if this peer should be activated because its HA group peers are active
|
||||
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
|
||||
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
|
||||
m.activateNewPeerInActiveGroup(peerCfg)
|
||||
m.activateNewPeerInActiveGroup(ctx, peerCfg)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
@@ -223,7 +237,7 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
|
||||
// AddActivePeers adds a list of peers to the lazy connection manager
|
||||
// suppose these peers was in connected or in connecting states
|
||||
func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
|
||||
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -233,7 +247,7 @@ func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.addActivePeer(&cfg); err != nil {
|
||||
if err := m.addActivePeer(ctx, cfg); err != nil {
|
||||
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -250,7 +264,7 @@ func (m *Manager) RemovePeer(peerID string) {
|
||||
|
||||
// ActivatePeer activates a peer connection when a signal message is received
|
||||
// Also activates all peers in the same HA groups as this peer
|
||||
func (m *Manager) ActivatePeer(peerID string) (found bool) {
|
||||
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
cfg, mp := m.getPeerForActivation(peerID)
|
||||
@@ -258,43 +272,15 @@ func (m *Manager) ActivatePeer(peerID string) (found bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
cfg.Log.Infof("activate peer from inactive state by remote signal message")
|
||||
|
||||
if !m.activateSinglePeer(cfg, mp) {
|
||||
if !m.activateSinglePeer(ctx, cfg, mp) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.activateHAGroupPeers(cfg)
|
||||
m.activateHAGroupPeers(ctx, peerID)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) DeactivatePeer(peerID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
|
||||
// Returns nil values if the peer should be skipped
|
||||
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
|
||||
@@ -316,36 +302,41 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma
|
||||
return cfg, mp
|
||||
}
|
||||
|
||||
// activateSinglePeer activates a single peer
|
||||
// return true if the peer was activated, false if it was already active
|
||||
func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
|
||||
if mp.expectedWatcher == watcherInactivity {
|
||||
// activateSinglePeer activates a single peer (internal method)
|
||||
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
|
||||
im, ok := m.inactivityMonitors[cfg.PeerConnID]
|
||||
if !ok {
|
||||
cfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
return false
|
||||
}
|
||||
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
m.inactivityManager.AddPeer(cfg)
|
||||
cfg.Log.Infof("starting inactivity monitor")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
|
||||
func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
|
||||
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
|
||||
var peersToActivate []string
|
||||
|
||||
m.routesMu.RLock()
|
||||
haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey]
|
||||
haGroups := m.peerToHAGroups[triggerPeerID]
|
||||
|
||||
if len(haGroups) == 0 {
|
||||
m.routesMu.RUnlock()
|
||||
triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups")
|
||||
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
|
||||
return
|
||||
}
|
||||
|
||||
for _, haGroup := range haGroups {
|
||||
peers := m.haGroupToPeers[haGroup]
|
||||
for _, peerID := range peers {
|
||||
if peerID != triggeredPeerCfg.PublicKey {
|
||||
if peerID != triggerPeerID {
|
||||
peersToActivate = append(peersToActivate, peerID)
|
||||
}
|
||||
}
|
||||
@@ -359,16 +350,16 @@ func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
|
||||
continue
|
||||
}
|
||||
|
||||
if m.activateSinglePeer(cfg, mp) {
|
||||
if m.activateSinglePeer(ctx, cfg, mp) {
|
||||
activatedCount++
|
||||
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey)
|
||||
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
if activatedCount > 0 {
|
||||
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
|
||||
activatedCount, triggeredPeerCfg.PublicKey, haGroups)
|
||||
activatedCount, triggerPeerID, haGroups)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,13 +394,13 @@ func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool)
|
||||
}
|
||||
|
||||
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
|
||||
func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
|
||||
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
|
||||
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if !m.activateSinglePeer(&peerCfg, mp) {
|
||||
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -417,19 +408,23 @@ func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error {
|
||||
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
|
||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||
peerCfg.Log.Warnf("peer already managed")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = peerCfg
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: peerCfg,
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherInactivity,
|
||||
}
|
||||
|
||||
m.inactivityManager.AddPeer(peerCfg)
|
||||
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -441,7 +436,12 @@ func (m *Manager) removePeer(peerID string) {
|
||||
|
||||
cfg.Log.Infof("removing lazy peer")
|
||||
|
||||
m.inactivityManager.RemovePeer(cfg.PublicKey)
|
||||
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
|
||||
im.Stop()
|
||||
delete(m.inactivityMonitors, cfg.PeerConnID)
|
||||
cfg.Log.Debugf("inactivity monitor stopped")
|
||||
}
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
delete(m.managedPeers, peerID)
|
||||
delete(m.managedPeersByConnID, cfg.PeerConnID)
|
||||
@@ -451,8 +451,12 @@ func (m *Manager) close() {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
m.connStateDispatcher.RemoveListener(m.connStateListener)
|
||||
m.activityManager.Close()
|
||||
|
||||
for _, iw := range m.inactivityMonitors {
|
||||
iw.Stop()
|
||||
}
|
||||
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
|
||||
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
|
||||
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
|
||||
|
||||
@@ -466,7 +470,7 @@ func (m *Manager) close() {
|
||||
}
|
||||
|
||||
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
|
||||
func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool {
|
||||
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
|
||||
m.routesMu.RLock()
|
||||
defer m.routesMu.RUnlock()
|
||||
|
||||
@@ -476,45 +480,38 @@ func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID
|
||||
}
|
||||
|
||||
for _, haGroup := range haGroups {
|
||||
if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active {
|
||||
return true
|
||||
groupPeers := m.haGroupToPeers[haGroup]
|
||||
|
||||
for _, groupPeerID := range groupPeers {
|
||||
if groupPeerID == peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, ok := m.managedPeers[groupPeerID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if groupMp.expectedWatcher != watcherInactivity {
|
||||
continue
|
||||
}
|
||||
|
||||
// Other member is still connected, defer idle
|
||||
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool {
|
||||
groupPeers := m.haGroupToPeers[haGroup]
|
||||
for _, groupPeerID := range groupPeers {
|
||||
|
||||
if groupPeerID == peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, ok := m.managedPeers[groupPeerID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if groupMp.expectedWatcher != watcherInactivity {
|
||||
continue
|
||||
}
|
||||
|
||||
// If any peer in the group is active, do defer idle
|
||||
if _, isInactive := inactivePeers[groupPeerID]; !isInactive {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -531,56 +528,100 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||
|
||||
mp.peerCfg.Log.Infof("detected peer activity")
|
||||
|
||||
if !m.activateSinglePeer(mp.peerCfg, mp) {
|
||||
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
|
||||
return
|
||||
}
|
||||
|
||||
m.activateHAGroupPeers(mp.peerCfg)
|
||||
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
|
||||
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
||||
func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
for peerID := range peerIDs {
|
||||
peerCfg, ok := m.managedPeers[peerID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by peerId: %v", peerID)
|
||||
continue
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by id: %v", peerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||
return
|
||||
}
|
||||
|
||||
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
|
||||
iw, ok := m.inactivityMonitors[peerConnID]
|
||||
if ok {
|
||||
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
|
||||
iw.ResetMonitor(ctx, m.onInactive)
|
||||
} else {
|
||||
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID)
|
||||
continue
|
||||
}
|
||||
mp.peerCfg.Log.Infof("connection timed out")
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||
continue
|
||||
}
|
||||
// this is blocking operation, potentially can be optimized
|
||||
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||
|
||||
if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) {
|
||||
mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers")
|
||||
continue
|
||||
}
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
|
||||
mp.peerCfg.Log.Infof("connection timed out")
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
// this is blocking operation, potentially can be optimized
|
||||
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
|
||||
// just in case free up
|
||||
m.inactivityMonitors[peerConnID].PauseTimer()
|
||||
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
||||
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
|
||||
iw.PauseTimer()
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
|
||||
iw.ResetTimer()
|
||||
}
|
||||
|
||||
@@ -6,13 +6,9 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
LastActivities() map[string]monotime.Time
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Name() string {
|
||||
return "nb0"
|
||||
return "wt0"
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
@@ -105,6 +106,10 @@ type Conn struct {
|
||||
workerRelay *WorkerRelay
|
||||
wgWatcherWg sync.WaitGroup
|
||||
|
||||
connIDRelay nbnet.ConnectionID
|
||||
connIDICE nbnet.ConnectionID
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||
rosenpassRemoteKey []byte
|
||||
|
||||
@@ -112,9 +117,10 @@ type Conn struct {
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
handshaker *Handshaker
|
||||
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
wg sync.WaitGroup
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
wg sync.WaitGroup
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
@@ -130,17 +136,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
peerConnDispatcher: services.PeerConnDispatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@@ -162,7 +169,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
@@ -219,7 +226,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
}
|
||||
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
func (conn *Conn) Close(signalToRemote bool) {
|
||||
func (conn *Conn) Close() {
|
||||
conn.mu.Lock()
|
||||
defer conn.wgWatcherWg.Wait()
|
||||
defer conn.mu.Unlock()
|
||||
@@ -229,12 +236,6 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if signalToRemote {
|
||||
if err := conn.signaler.SignalIdle(conn.config.Key); err != nil {
|
||||
conn.Log.Errorf("failed to signal idle state to peer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn.Log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
@@ -262,6 +263,8 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
conn.freeUpConnID()
|
||||
|
||||
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
|
||||
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
@@ -286,6 +289,13 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
|
||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||
}
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
|
||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
|
||||
conn.onConnected = handler
|
||||
@@ -373,6 +383,10 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
ep = directEp
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||
|
||||
@@ -390,10 +404,15 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
oldState := conn.currentConnPriority
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
|
||||
if oldState == conntype.None {
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
@@ -431,6 +450,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||
@@ -471,8 +491,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
|
||||
|
||||
conn.dumpState.NewLocalProxy()
|
||||
|
||||
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
@@ -485,6 +503,10 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
@@ -508,6 +530,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.Log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
|
||||
func (conn *Conn) onRelayDisconnected() {
|
||||
@@ -522,7 +545,11 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
|
||||
if conn.currentConnPriority == conntype.Relay {
|
||||
conn.Log.Debugf("clean up WireGuard config")
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -685,6 +712,36 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDRelay != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDRelay); err != nil {
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDRelay = ""
|
||||
}
|
||||
|
||||
if conn.connIDICE != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDICE); err != nil {
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDICE = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||
udpAddr := &net.UDPAddr{
|
||||
|
||||
@@ -68,13 +68,3 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Signaler) SignalIdle(remoteKey string) error {
|
||||
return s.signal.Send(&sProto.Message{
|
||||
Key: s.wgPrivateKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey,
|
||||
Body: &sProto.Body{
|
||||
Type: sProto.Body_GO_IDLE,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ type RelayConnInfo struct {
|
||||
}
|
||||
|
||||
type WorkerRelay struct {
|
||||
peerCtx context.Context
|
||||
log *log.Entry
|
||||
isController bool
|
||||
config ConnConfig
|
||||
@@ -34,9 +33,8 @@ type WorkerRelay struct {
|
||||
wgWatcher *WGWatcher
|
||||
}
|
||||
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
||||
r := &WorkerRelay{
|
||||
peerCtx: ctx,
|
||||
log: log,
|
||||
isController: ctrl,
|
||||
config: config,
|
||||
@@ -64,7 +62,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
|
||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
|
||||
if err != nil {
|
||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||
|
||||
@@ -95,17 +95,6 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
||||
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnIdle(pubKey string) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.Close(true)
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnClose(pubKey string) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
@@ -114,7 +103,7 @@ func (s *Store) PeerConnClose(pubKey string) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.Close(false)
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func (s *Store) PeersPubKey() []string {
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -552,16 +553,41 @@ func (w *Watcher) Stop() {
|
||||
w.currentChosenStatus = nil
|
||||
}
|
||||
|
||||
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||
switch handlerType(params.Route, params.UseNewDNSRoute) {
|
||||
func HandlerFromRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsRouterInteval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.WGIface,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
useNewDNSRoute bool,
|
||||
) RouteHandler {
|
||||
switch handlerType(rt, useNewDNSRoute) {
|
||||
case handlerTypeDnsInterceptor:
|
||||
return dnsinterceptor.New(params)
|
||||
return dnsinterceptor.New(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
statusRecorder,
|
||||
dnsServer,
|
||||
wgInterface,
|
||||
peerStore,
|
||||
)
|
||||
case handlerTypeDynamic:
|
||||
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||
return dynamic.NewRoute(params, dnsAddr)
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
dnsRouterInteval,
|
||||
statusRecorder,
|
||||
wgInterface,
|
||||
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||
)
|
||||
default:
|
||||
return static.NewRoute(params)
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
statuses map[route.ID]routerPeerStatus
|
||||
@@ -811,12 +811,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||
}
|
||||
|
||||
params := common.HandlerParams{
|
||||
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||
}
|
||||
// create new clientNetwork
|
||||
client := &Watcher{
|
||||
handler: static.NewRoute(params),
|
||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||
routes: tc.existingRoutes,
|
||||
currentChosen: currentRoute,
|
||||
}
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type HandlerParams struct {
|
||||
Route *route.Route
|
||||
RouteRefCounter *refcounter.RouteRefCounter
|
||||
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
DnsRouterInterval time.Duration
|
||||
StatusRecorder *peer.Status
|
||||
WgInterface iface.WGIface
|
||||
DnsServer dns.Server
|
||||
PeerStore *peerstore.Store
|
||||
UseNewDNSRoute bool
|
||||
Firewall manager.Manager
|
||||
FakeIPManager *fakeip.Manager
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -13,14 +12,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -28,11 +24,6 @@ import (
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type internalDNATer interface {
|
||||
RemoveInternalDNATMapping(netip.Addr) error
|
||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||
}
|
||||
|
||||
type wgInterface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
@@ -49,22 +40,26 @@ type DnsInterceptor struct {
|
||||
interceptedDomains domainMap
|
||||
wgInterface wgInterface
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
}
|
||||
|
||||
func New(params common.HandlerParams) *DnsInterceptor {
|
||||
func New(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
statusRecorder *peer.Status,
|
||||
dnsServer nbdns.Server,
|
||||
wgInterface wgInterface,
|
||||
peerStore *peerstore.Store,
|
||||
) *DnsInterceptor {
|
||||
return &DnsInterceptor{
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
statusRecorder: params.StatusRecorder,
|
||||
dnsServer: params.DnsServer,
|
||||
wgInterface: params.WgInterface,
|
||||
peerStore: params.PeerStore,
|
||||
firewall: params.Firewall,
|
||||
fakeIPManager: params.FakeIPManager,
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dnsServer,
|
||||
wgInterface: wgInterface,
|
||||
interceptedDomains: make(domainMap),
|
||||
peerStore: peerStore,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,13 +78,9 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
// Routes should use fake IPs
|
||||
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||
}
|
||||
|
||||
// AllowedIPs should use real IPs
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
@@ -97,10 +88,8 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
}
|
||||
}
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
}
|
||||
|
||||
d.cleanupDNATMappings()
|
||||
|
||||
for _, domain := range d.route.Domains {
|
||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
@@ -113,68 +102,6 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
|
||||
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
|
||||
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
|
||||
return realPrefix
|
||||
}
|
||||
|
||||
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
|
||||
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
|
||||
}
|
||||
|
||||
return realPrefix
|
||||
}
|
||||
|
||||
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
|
||||
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
|
||||
// AllowedIPs always use real IPs
|
||||
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
|
||||
}
|
||||
|
||||
if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
realPrefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
|
||||
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
|
||||
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
|
||||
routePrefix := d.transformRealToFakePrefix(realPrefix)
|
||||
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
|
||||
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
|
||||
}
|
||||
|
||||
// Add to AllowedIPs if we have a current peer (uses real IPs)
|
||||
if d.currentPeerKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
|
||||
}
|
||||
|
||||
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
|
||||
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
|
||||
if d.currentPeerKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowedIPs use real IPs
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
|
||||
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
@@ -182,9 +109,14 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
// AllowedIPs use real IPs
|
||||
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,7 +132,6 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
var merr *multierror.Error
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
// AllowedIPs use real IPs
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
@@ -356,8 +287,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,22 +297,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// logPrefixChanges handles the logging for prefix changes
|
||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
@@ -392,163 +305,70 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
var merr *multierror.Error
|
||||
var dnatMappings map[netip.Addr]netip.Addr
|
||||
|
||||
// Handle DNAT mappings for new prefixes
|
||||
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
|
||||
dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
for _, prefix := range toAdd {
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||
dnatMappings[fakeIP] = realIP
|
||||
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||
} else {
|
||||
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add new prefixes
|
||||
for _, prefix := range toAdd {
|
||||
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if d.currentPeerKey == "" {
|
||||
continue
|
||||
}
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
resolvedDomain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
d.addDNATMappings(dnatMappings)
|
||||
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
for _, prefix := range toRemove {
|
||||
// Routes use fake IPs
|
||||
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
||||
}
|
||||
// AllowedIPs use real IPs
|
||||
if err := d.removeAllowedIP(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.removeDNATMappings(toRemove)
|
||||
}
|
||||
|
||||
// Update domain prefixes using resolved domain as key - store real IPs
|
||||
// Update domain prefixes using resolved domain as key
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
if d.route.KeepRoute {
|
||||
// replace stored prefixes with old + added
|
||||
// nolint:gocritic
|
||||
newPrefixes = append(oldPrefixes, toAdd...)
|
||||
}
|
||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||
|
||||
// Store real IPs for status (user-facing), not fake IPs
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||
|
||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
||||
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
||||
if len(realPrefixes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnatFirewall, ok := d.internalDnatFw()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, prefix := range realPrefixes {
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
} else {
|
||||
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalDnatFw checks if the firewall supports internal DNAT
|
||||
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||
if d.firewall == nil || runtime.GOOS != "android" {
|
||||
return nil, false
|
||||
}
|
||||
fw, ok := d.firewall.(internalDNATer)
|
||||
return fw, ok
|
||||
}
|
||||
|
||||
// addDNATMappings adds DNAT mappings to the firewall
|
||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnatFirewall, ok := d.internalDnatFw()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for fakeIP, realIP := range mappings {
|
||||
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||
} else {
|
||||
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupDNATMappings removes all DNAT mappings for this interceptor
|
||||
func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
d.removeDNATMappings(prefixes)
|
||||
}
|
||||
}
|
||||
|
||||
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Replace A and AAAA records with fake IPs
|
||||
for _, answer := range reply.Answer {
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
realIP, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.A = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
|
||||
case *dns.AAAA:
|
||||
realIP, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.AAAA = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
@@ -53,16 +52,24 @@ type Route struct {
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
||||
func NewRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.WGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
interval: params.DnsRouterInterval,
|
||||
statusRecorder: params.StatusRecorder,
|
||||
wgInterface: params.WgInterface,
|
||||
resolverAddr: resolverAddr,
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
resolverAddr: resolverAddr,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
package fakeip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
nextIP netip.Addr // Next IP to allocate
|
||||
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||
baseIP netip.Addr // First usable IP: 240.0.0.1
|
||||
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
||||
}
|
||||
|
||||
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
||||
func NewManager() *Manager {
|
||||
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||
|
||||
return &Manager{
|
||||
nextIP: baseIP,
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: baseIP,
|
||||
maxIP: maxIP,
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateFakeIP allocates a fake IP for the given real IP
|
||||
// Returns the fake IP, or existing fake IP if already allocated
|
||||
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||
if !realIP.Is4() {
|
||||
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if fakeIP, exists := m.allocated[realIP]; exists {
|
||||
return fakeIP, nil
|
||||
}
|
||||
|
||||
startIP := m.nextIP
|
||||
for {
|
||||
currentIP := m.nextIP
|
||||
|
||||
// Advance to next IP, wrapping at boundary
|
||||
if m.nextIP.Compare(m.maxIP) >= 0 {
|
||||
m.nextIP = m.baseIP
|
||||
} else {
|
||||
m.nextIP = m.nextIP.Next()
|
||||
}
|
||||
|
||||
// Check if current IP is available
|
||||
if _, inUse := m.fakeToReal[currentIP]; !inUse {
|
||||
m.allocated[realIP] = currentIP
|
||||
m.fakeToReal[currentIP] = realIP
|
||||
return currentIP, nil
|
||||
}
|
||||
|
||||
// Prevent infinite loop if all IPs exhausted
|
||||
if m.nextIP.Compare(startIP) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetFakeIP returns the fake IP for a real IP if it exists
|
||||
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
fakeIP, exists := m.allocated[realIP]
|
||||
return fakeIP, exists
|
||||
}
|
||||
|
||||
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
||||
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
realIP, exists := m.fakeToReal[fakeIP]
|
||||
return realIP, exists
|
||||
}
|
||||
|
||||
// GetFakeIPBlock returns the fake IP block used by this manager
|
||||
func (m *Manager) GetFakeIPBlock() netip.Prefix {
|
||||
return netip.MustParsePrefix("240.0.0.0/8")
|
||||
}
|
||||
@@ -1,240 +0,0 @@
|
||||
package fakeip
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
if manager.baseIP.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
||||
}
|
||||
|
||||
if manager.maxIP.String() != "240.255.255.254" {
|
||||
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
||||
}
|
||||
|
||||
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
||||
t.Errorf("Expected nextIP to start at baseIP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateFakeIP(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("8.8.8.8")
|
||||
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||
}
|
||||
|
||||
if !fakeIP.Is4() {
|
||||
t.Error("Fake IP should be IPv4")
|
||||
}
|
||||
|
||||
// Check it's in the correct range
|
||||
if fakeIP.As4()[0] != 240 {
|
||||
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||
}
|
||||
|
||||
// Should return same fake IP for same real IP
|
||||
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get existing fake IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP.Compare(fakeIP2) != 0 {
|
||||
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
||||
|
||||
_, err := manager.AllocateFakeIP(realIPv6)
|
||||
if err == nil {
|
||||
t.Error("Expected error for IPv6 address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFakeIP(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Should not exist initially
|
||||
_, exists := manager.GetFakeIP(realIP)
|
||||
if exists {
|
||||
t.Error("Fake IP should not exist before allocation")
|
||||
}
|
||||
|
||||
// Allocate and check
|
||||
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate: %v", err)
|
||||
}
|
||||
|
||||
fakeIP, exists := manager.GetFakeIP(realIP)
|
||||
if !exists {
|
||||
t.Error("Fake IP should exist after allocation")
|
||||
}
|
||||
|
||||
if fakeIP.Compare(expectedFakeIP) != 0 {
|
||||
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleAllocations(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
allocations := make(map[netip.Addr]netip.Addr)
|
||||
|
||||
// Allocate multiple IPs
|
||||
for i := 1; i <= 100; i++ {
|
||||
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
for _, existingFake := range allocations {
|
||||
if fakeIP.Compare(existingFake) == 0 {
|
||||
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
allocations[realIP] = fakeIP
|
||||
}
|
||||
|
||||
// Verify all allocations can be retrieved
|
||||
for realIP, expectedFake := range allocations {
|
||||
actualFake, exists := manager.GetFakeIP(realIP)
|
||||
if !exists {
|
||||
t.Errorf("Missing allocation for %s", realIP.String())
|
||||
}
|
||||
if actualFake.Compare(expectedFake) != 0 {
|
||||
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFakeIPBlock(t *testing.T) {
|
||||
manager := NewManager()
|
||||
block := manager.GetFakeIPBlock()
|
||||
|
||||
expected := "240.0.0.0/8"
|
||||
if block.String() != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, block.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
const numGoroutines = 50
|
||||
const allocationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||
|
||||
// Concurrent allocations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < allocationsPerGoroutine; j++ {
|
||||
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
|
||||
return
|
||||
}
|
||||
results <- fakeIP
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Check for duplicates
|
||||
seen := make(map[netip.Addr]bool)
|
||||
count := 0
|
||||
for fakeIP := range results {
|
||||
if seen[fakeIP] {
|
||||
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
|
||||
}
|
||||
seen[fakeIP] = true
|
||||
count++
|
||||
}
|
||||
|
||||
if count != numGoroutines*allocationsPerGoroutine {
|
||||
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPExhaustion(t *testing.T) {
|
||||
// Create a manager with limited range for testing
|
||||
manager := &Manager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
||||
}
|
||||
|
||||
// Allocate all available IPs
|
||||
realIPs := []netip.Addr{
|
||||
netip.MustParseAddr("1.0.0.1"),
|
||||
netip.MustParseAddr("1.0.0.2"),
|
||||
netip.MustParseAddr("1.0.0.3"),
|
||||
}
|
||||
|
||||
for _, realIP := range realIPs {
|
||||
_, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to allocate one more - should fail
|
||||
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||
if err == nil {
|
||||
t.Error("Expected exhaustion error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapAround(t *testing.T) {
|
||||
// Create manager starting near the end of range
|
||||
manager := &Manager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
}
|
||||
|
||||
// Allocate the last IP
|
||||
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP1.String() != "240.0.0.254" {
|
||||
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||
}
|
||||
|
||||
// Next allocation should wrap around to the beginning
|
||||
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP2.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -26,8 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
@@ -44,7 +40,7 @@ import (
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() error
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||
TriggerSelection(route.HAMap)
|
||||
@@ -53,7 +49,7 @@ type Manager interface {
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
SetFirewall(firewall.Manager) error
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
@@ -67,7 +63,6 @@ type ManagerConfig struct {
|
||||
InitialRoutes []*route.Route
|
||||
StateManager *statemanager.Manager
|
||||
DNSServer dns.Server
|
||||
DNSFeatureFlag bool
|
||||
PeerStore *peerstore.Store
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -94,13 +89,11 @@ type DefaultManager struct {
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
dnsServer dns.Server
|
||||
firewall firewall.Manager
|
||||
peerStore *peerstore.Store
|
||||
useNewDNSRoute bool
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||
fakeIPManager *fakeip.Manager
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
@@ -136,31 +129,11 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
dm.setupAndroidRoutes(config)
|
||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||
dm.notifier.SetInitialClientRoutes(cr)
|
||||
}
|
||||
return dm
|
||||
}
|
||||
func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
cr := m.initialClientRoutes(config.InitialRoutes)
|
||||
|
||||
routesForComparison := slices.Clone(cr)
|
||||
|
||||
if config.DNSFeatureFlag {
|
||||
m.fakeIPManager = fakeip.NewManager()
|
||||
|
||||
id := uuid.NewString()
|
||||
fakeIPRoute := &route.Route{
|
||||
ID: route.ID(id),
|
||||
Network: m.fakeIPManager.GetFakeIPBlock(),
|
||||
NetID: route.NetID(id),
|
||||
Peer: m.pubKey,
|
||||
NetworkType: route.IPv4Network,
|
||||
}
|
||||
cr = append(cr, fakeIPRoute)
|
||||
}
|
||||
|
||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||
m.routeRefCounter = refcounter.New(
|
||||
@@ -201,11 +174,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() error {
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
m.routeSelector = m.initSelector()
|
||||
|
||||
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||
@@ -219,12 +192,13 @@ func (m *DefaultManager) Init() error {
|
||||
|
||||
ips := resolveURLsToIPs(initialAddresses)
|
||||
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
||||
return fmt.Errorf("setup routing: %w", err)
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Routing setup complete")
|
||||
return nil
|
||||
return beforePeerHook, afterPeerHook, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
@@ -248,16 +222,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall manager for the DefaultManager
|
||||
// Not thread-safe, should be called before starting the manager
|
||||
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||
m.firewall = firewall
|
||||
|
||||
if m.disableServerRoutes || firewall == nil {
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
if m.disableServerRoutes {
|
||||
log.Info("server routes are disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if firewall == nil {
|
||||
return errors.New("firewall manager is not set")
|
||||
}
|
||||
|
||||
var err error
|
||||
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
if err != nil {
|
||||
@@ -325,20 +299,17 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
}
|
||||
|
||||
for id, route := range toAdd {
|
||||
params := common.HandlerParams{
|
||||
Route: route,
|
||||
RouteRefCounter: m.routeRefCounter,
|
||||
AllowedIPsRefCounter: m.allowedIPsRefCounter,
|
||||
DnsRouterInterval: m.dnsRouteInterval,
|
||||
StatusRecorder: m.statusRecorder,
|
||||
WgInterface: m.wgInterface,
|
||||
DnsServer: m.dnsServer,
|
||||
PeerStore: m.peerStore,
|
||||
UseNewDNSRoute: m.useNewDNSRoute,
|
||||
Firewall: m.firewall,
|
||||
FakeIPManager: m.fakeIPManager,
|
||||
}
|
||||
handler := client.HandlerFromRoute(params)
|
||||
handler := client.HandlerFromRoute(
|
||||
route,
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsRouteInterval,
|
||||
m.statusRecorder,
|
||||
m.wgInterface,
|
||||
m.dnsServer,
|
||||
m.peerStore,
|
||||
m.useNewDNSRoute,
|
||||
)
|
||||
if err := handler.AddRoute(m.ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||
continue
|
||||
@@ -546,7 +517,6 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
}
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
|
||||
@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
StatusRecorder: statusRecorder,
|
||||
})
|
||||
|
||||
err = routeManager.Init()
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
require.NoError(t, err, "should init route manager")
|
||||
defer routeManager.Stop(nil)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
@@ -22,8 +23,8 @@ type MockManager struct {
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() error {
|
||||
return nil
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||
@@ -86,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
||||
|
||||
}
|
||||
|
||||
func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
|
||||
124
client/internal/routemanager/notifier/notifier.go
Normal file
124
client/internal/routemanager/notifier/notifier.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
initialRouteRanges []string
|
||||
routeRanges []string
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets := make([]string, 0)
|
||||
for _, r := range clientRoutes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
nets = append(nets, r.Network.String())
|
||||
}
|
||||
sort.Strings(nets)
|
||||
n.initialRouteRanges = nets
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
if runtime.GOOS != "android" {
|
||||
return
|
||||
}
|
||||
|
||||
var newNets []string
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
newNets = append(newNets, r.Network.String())
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
// OnNewPrefixes is called from iOS only
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
for _, prefix := range prefixes {
|
||||
newNets = append(newNets, prefix.String())
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
func (n *Notifier) hasDiff(a []string, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return true
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return addIPv6RangeIfNeeded(n.initialRouteRanges)
|
||||
}
|
||||
|
||||
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
|
||||
func addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||
ranges := inputRanges
|
||||
for _, r := range inputRanges {
|
||||
// we are intentionally adding the ipv6 default range in case of ipv4 default range
|
||||
// to ensure that all traffic is managed by the tunnel interface on android
|
||||
if r == "0.0.0.0/0" {
|
||||
ranges = append(ranges, "::/0")
|
||||
break
|
||||
}
|
||||
}
|
||||
return ranges
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
initialRoutes []*route.Route
|
||||
currentRoutes []*route.Route
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||
// initialRoutes contains fake IP block for interface configuration
|
||||
filteredInitial := make([]*route.Route, 0)
|
||||
for _, r := range initialRoutes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
filteredInitial = append(filteredInitial, r)
|
||||
}
|
||||
n.initialRoutes = filteredInitial
|
||||
|
||||
// routesForComparison excludes fake IP block for comparison with new routes
|
||||
filteredComparison := make([]*route.Route, 0)
|
||||
for _, r := range routesForComparison {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
filteredComparison = append(filteredComparison, r)
|
||||
}
|
||||
n.currentRoutes = filteredComparison
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
var newRoutes []*route.Route
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
newRoutes = append(newRoutes, r)
|
||||
}
|
||||
}
|
||||
|
||||
if !n.hasRouteDiff(n.currentRoutes, newRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
n.currentRoutes = newRoutes
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes([]netip.Prefix) {
|
||||
// Not used on Android
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
routeStrings := n.routesToStrings(n.currentRoutes)
|
||||
sort.Strings(routeStrings)
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
|
||||
nets := make([]string, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
nets = append(nets, r.NetString())
|
||||
}
|
||||
return nets
|
||||
}
|
||||
|
||||
func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool {
|
||||
slices.SortFunc(a, func(x, y *route.Route) int {
|
||||
return strings.Compare(x.NetString(), y.NetString())
|
||||
})
|
||||
slices.SortFunc(b, func(x, y *route.Route) int {
|
||||
return strings.Compare(x.NetString(), y.NetString())
|
||||
})
|
||||
|
||||
return !slices.EqualFunc(a, b, func(x, y *route.Route) bool {
|
||||
return x.NetString() == y.NetString()
|
||||
})
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
initialStrings := n.routesToStrings(n.initialRoutes)
|
||||
sort.Strings(initialStrings)
|
||||
return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes)
|
||||
}
|
||||
|
||||
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string {
|
||||
for _, r := range routes {
|
||||
if r.Network.Addr().Is4() && r.Network.Bits() == 0 {
|
||||
return append(slices.Clone(inputRanges), "::/0")
|
||||
}
|
||||
}
|
||||
return inputRanges
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
currentPrefixes []string
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// iOS doesn't care about initial routes
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||
// Not used on iOS
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
for _, prefix := range prefixes {
|
||||
newNets = append(newNets, prefix.String())
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
|
||||
if slices.Equal(n.currentPrefixes, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||
for _, r := range inputRanges {
|
||||
if r == "0.0.0.0/0" {
|
||||
return append(slices.Clone(inputRanges), "::/0")
|
||||
}
|
||||
}
|
||||
return inputRanges
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct{}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return []string{}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -17,11 +16,11 @@ type Route struct {
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
}
|
||||
|
||||
func NewRoute(params common.HandlerParams) *Route {
|
||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
||||
return &Route{
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
@@ -54,13 +52,6 @@ type SysOps struct {
|
||||
mu sync.Mutex
|
||||
// notifier is used to notify the system of route changes (also used on mobile)
|
||||
notifier *notifier.Notifier
|
||||
// seq is an atomic counter for generating unique sequence numbers for route messages
|
||||
//nolint:unused // only used on BSD systems
|
||||
seq atomic.Uint32
|
||||
|
||||
localSubnetsCache []*net.IPNet
|
||||
localSubnetsCacheMu sync.RWMutex
|
||||
localSubnetsCacheTime time.Time
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
@@ -70,11 +61,6 @@ func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:unused // only used on BSD systems
|
||||
func (r *SysOps) getSeq() int {
|
||||
return int(r.seq.Add(1))
|
||||
}
|
||||
|
||||
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
return nil
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/libp2p/go-netroute"
|
||||
@@ -25,8 +24,6 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const localSubnetsCacheTTL = 15 * time.Minute
|
||||
|
||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
@@ -34,7 +31,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
@@ -78,10 +75,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
if err := r.setupHooks(initAddresses, stateManager); err != nil {
|
||||
return fmt.Errorf("setup hooks: %w", err)
|
||||
}
|
||||
return nil
|
||||
return r.setupHooks(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
// updateState updates state on every change so it will be persisted regularly
|
||||
@@ -134,14 +128,18 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
|
||||
exitNextHop := nexthop
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
||||
exitNextHop := Nexthop{
|
||||
IP: nexthop.IP,
|
||||
Intf: nexthop.Intf,
|
||||
}
|
||||
|
||||
vpnAddr := vpnIntf.Address().IP
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
||||
|
||||
exitNextHop = initialNextHop
|
||||
}
|
||||
|
||||
@@ -154,37 +152,12 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
}
|
||||
|
||||
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
||||
r.localSubnetsCacheMu.RLock()
|
||||
cacheAge := time.Since(r.localSubnetsCacheTime)
|
||||
subnets := r.localSubnetsCache
|
||||
r.localSubnetsCacheMu.RUnlock()
|
||||
|
||||
if cacheAge > localSubnetsCacheTTL || subnets == nil {
|
||||
r.localSubnetsCacheMu.Lock()
|
||||
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
|
||||
r.refreshLocalSubnetsCache()
|
||||
}
|
||||
subnets = r.localSubnetsCache
|
||||
r.localSubnetsCacheMu.Unlock()
|
||||
}
|
||||
|
||||
for _, subnet := range subnets {
|
||||
if subnet.Contains(prefix.Addr().AsSlice()) {
|
||||
return true, subnet
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) refreshLocalSubnetsCache() {
|
||||
localInterfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get local interfaces: %v", err)
|
||||
return
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var newSubnets []*net.IPNet
|
||||
for _, intf := range localInterfaces {
|
||||
addrs, err := intf.Addrs()
|
||||
if err != nil {
|
||||
@@ -198,12 +171,14 @@ func (r *SysOps) refreshLocalSubnetsCache() {
|
||||
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
||||
continue
|
||||
}
|
||||
newSubnets = append(newSubnets, ipnet)
|
||||
|
||||
if ipnet.Contains(prefix.Addr().AsSlice()) {
|
||||
return true, ipnet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.localSubnetsCache = newSubnets
|
||||
r.localSubnetsCacheTime = time.Now()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
@@ -289,7 +264,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
@@ -314,11 +289,9 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,11 +300,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
@@ -346,16 +319,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
|
||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
|
||||
@@ -10,13 +10,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
|
||||
@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
rules := getSetupRules()
|
||||
for _, rule := range rules {
|
||||
if err := addRule(rule); err != nil {
|
||||
return fmt.Errorf("%s: %w", rule.description, err)
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
|
||||
@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "nb0",
|
||||
name: "wt0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
|
||||
@@ -18,9 +18,10 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
@@ -107,7 +108,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
const numAddrs = unix.RTAX_NETMASK + 1
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const InfiniteLifetime = 0xffffffff
|
||||
@@ -136,7 +137,7 @@ const (
|
||||
RouteDeleted
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user