mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 05:56:27 -04:00
Compare commits
71 Commits
feature/bu
...
apply-rout
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11d1ca6fbc | ||
|
|
9424b88db2 | ||
|
|
609654eee7 | ||
|
|
b604c66140 | ||
|
|
ea4d13e96d | ||
|
|
87148c503f | ||
|
|
0cd36baf67 | ||
|
|
06980e7fa0 | ||
|
|
1ce4ee0cef | ||
|
|
f367925496 | ||
|
|
1aa4ec74d4 | ||
|
|
616b19c064 | ||
|
|
af27aaf9af | ||
|
|
35287f8241 | ||
|
|
07b220d91b | ||
|
|
7227979514 | ||
|
|
fdb155f34b | ||
|
|
ef5027ab2a | ||
|
|
a29b28390b | ||
|
|
aa2662a2bb | ||
|
|
3ecab1f31e | ||
|
|
8be95a2efb | ||
|
|
2a3623d6ac | ||
|
|
fbdccbc2a1 | ||
|
|
855d21c37e | ||
|
|
db9facf9cb | ||
|
|
41cd4952f1 | ||
|
|
dfe7f91ddd | ||
|
|
bc71beca97 | ||
|
|
d6444e14e4 | ||
|
|
ae01335bfe | ||
|
|
8208a7939c | ||
|
|
f16f0c7831 | ||
|
|
aa07b3b87b | ||
|
|
2bef214cc0 | ||
|
|
cfb2d82352 | ||
|
|
684501fd35 | ||
|
|
0492c1724a | ||
|
|
6f436e57b5 | ||
|
|
a0d28f9851 | ||
|
|
cdd27a9fe5 | ||
|
|
5523040acd | ||
|
|
670446d42e | ||
|
|
5bed6777d5 | ||
|
|
a0482ebc7b | ||
|
|
2a89d6e47a | ||
|
|
24f932b2ce | ||
|
|
c03435061c | ||
|
|
8e948739f1 | ||
|
|
9b53cad752 | ||
|
|
802a18167c | ||
|
|
e9108ffe6c | ||
|
|
e806d9de38 | ||
|
|
daa8380df9 | ||
|
|
4785f23fc4 | ||
|
|
1d4cfb83e7 | ||
|
|
207fa059d2 | ||
|
|
cbcdad7814 | ||
|
|
701c13807a | ||
|
|
99f8dc7748 | ||
|
|
f1de8e6eb0 | ||
|
|
b2a10780af | ||
|
|
43ae79d848 | ||
|
|
e520b64c6d | ||
|
|
92c91bbdd8 | ||
|
|
adf494e1ac | ||
|
|
2158461121 | ||
|
|
0cd4b601c3 | ||
|
|
ee1cec47b3 | ||
|
|
efb0edfc4c | ||
|
|
20f59ddecb |
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -37,16 +37,21 @@ If yes, which one?
|
||||
|
||||
**Debug output**
|
||||
|
||||
To help us resolve the problem, please attach the following debug output
|
||||
To help us resolve the problem, please attach the following anonymized status output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
As well as the file created by
|
||||
Create and upload a debug bundle, and share the returned file key:
|
||||
|
||||
netbird debug for 1m -AS -U
|
||||
|
||||
*Uploaded files are automatically deleted after 30 days.*
|
||||
|
||||
|
||||
Alternatively, create the file only and attach it here manually:
|
||||
|
||||
netbird debug for 1m -AS
|
||||
|
||||
|
||||
We advise reviewing the anonymized output for any remaining personal information.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -13,3 +13,5 @@
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] Extended the README / documentation, if necessary
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
8
.github/workflows/golang-test-linux.yml
vendored
8
.github/workflows/golang-test-linux.yml
vendored
@@ -223,6 +223,10 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -269,6 +273,10 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
|
||||
1
.github/workflows/golangci-lint.yml
vendored
1
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,6 @@ jobs:
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -179,6 +179,7 @@ jobs:
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
# For details on buf.gen.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-gen-yaml/
|
||||
version: v2
|
||||
plugins:
|
||||
- remote: buf.build/protocolbuffers/go:v1.35.1
|
||||
out: .
|
||||
- remote: buf.build/grpc/go:v1.5.1
|
||||
out: .
|
||||
10
buf.yaml
10
buf.yaml
@@ -1,10 +0,0 @@
|
||||
# For details on buf.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-yaml
|
||||
version: v2
|
||||
modules:
|
||||
- path: proto
|
||||
lint:
|
||||
use:
|
||||
- BASIC
|
||||
breaking:
|
||||
use:
|
||||
- FILE
|
||||
@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
||||
return a.ipAnonymizer[ip]
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
||||
// Convert IP to netip.Addr
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return addr
|
||||
}
|
||||
|
||||
anonIP := a.AnonymizeIP(ip)
|
||||
|
||||
return net.UDPAddr{
|
||||
IP: anonIP.AsSlice(),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
}
|
||||
|
||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -98,11 +99,11 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
}
|
||||
}
|
||||
|
||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isLinuxRunningDesktop() bool {
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
@@ -26,22 +26,22 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -77,9 +77,9 @@ var (
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
blockLANAccess bool
|
||||
debugUploadBundle bool
|
||||
debugUploadBundleURL string
|
||||
lazyConnEnabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -184,6 +184,7 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
|
||||
@@ -30,7 +30,7 @@ func newSVCConfig() *service.Config {
|
||||
return &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||
Description: "Netbird mesh network client",
|
||||
Option: make(service.KeyValue),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||
}
|
||||
|
||||
if logFile != "console" {
|
||||
if logFile != "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ func init() {
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
case "", "idle", "connecting", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||
}
|
||||
|
||||
if len(ipsFilter) > 0 {
|
||||
|
||||
@@ -6,6 +6,8 @@ const (
|
||||
disableServerRoutesFlag = "disable-server-routes"
|
||||
disableDNSFlag = "disable-dns"
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
blockInboundFlag = "block-inbound"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -13,6 +15,8 @@ var (
|
||||
disableServerRoutes bool
|
||||
disableDNS bool
|
||||
disableFirewall bool
|
||||
blockLANAccess bool
|
||||
blockInbound bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -28,4 +32,11 @@ func init() {
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
|
||||
280
client/cmd/up.go
280
client/cmd/up.go
@@ -55,12 +55,11 @@ func init() {
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
@@ -119,79 +118,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup config: %v", err)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
@@ -199,7 +128,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
config, err := internal.UpdateOrCreateConfig(*ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
@@ -258,21 +187,153 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get setup key: %v", err)
|
||||
}
|
||||
|
||||
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup login request: %v", err)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -297,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
loginRequest.InterfaceName = &interfaceName
|
||||
}
|
||||
@@ -332,45 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
loginRequest.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
return &loginRequest, nil
|
||||
}
|
||||
|
||||
func validateNATExternalIPs(list []string) error {
|
||||
|
||||
@@ -147,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -198,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
@@ -219,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -248,10 +248,6 @@ func (r *router) deleteIpSet(setName string) error {
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
@@ -278,10 +274,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
|
||||
@@ -116,6 +116,8 @@ type Manager interface {
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
IsStateful() bool
|
||||
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
|
||||
@@ -170,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -324,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package nftables
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs2 := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
Data: ip.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
|
||||
@@ -573,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -1006,10 +1002,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ type Forwarder struct {
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
}
|
||||
|
||||
@@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
@@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
|
||||
}
|
||||
|
||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
|
||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return value.([]byte), true
|
||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||
|
||||
@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
|
||||
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
|
||||
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
wg.Wait()
|
||||
|
||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
|
||||
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
|
||||
}
|
||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
|
||||
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
|
||||
@@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
}
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
|
||||
@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
expected: true,
|
||||
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
expected: true,
|
||||
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
expected: true,
|
||||
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
expected: true,
|
||||
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
|
||||
@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -39,8 +39,12 @@ const (
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
|
||||
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
|
||||
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
|
||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||
)
|
||||
|
||||
@@ -71,7 +75,6 @@ type Manager struct {
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@@ -148,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
|
||||
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
@@ -269,7 +277,7 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||
case !m.netstack && m.nativeFirewall != nil:
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
@@ -326,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return m.stateful
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
@@ -606,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
if m.stateful {
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
}
|
||||
// for netflow we keep track even if the firewall is stateless
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -777,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack && m.localForwarding {
|
||||
return m.handleNetstackLocalTraffic(packetData)
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
// track inbound packets to get the correct direction and session id for flows
|
||||
@@ -789,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||
fwd := m.forwarder.Load()
|
||||
if fwd == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
@@ -1088,11 +1099,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
|
||||
@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Apply scenario-specific setup
|
||||
sc.setupFunc(manager)
|
||||
|
||||
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Pre-populate connection table
|
||||
srcIPs := generateRandomIPs(count)
|
||||
dstIPs := generateRandomIPs(count)
|
||||
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
srcIP := generateRandomIPs(1)[0]
|
||||
dstIP := generateRandomIPs(1)[0]
|
||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "post_handshake",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -19,12 +19,8 @@ import (
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
localIP := net.ParseIP("100.10.0.100")
|
||||
wgNet := &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
@@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = wgNet
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
localIP, wgNet, err := net.ParseCIDR(network)
|
||||
require.NoError(tb, err)
|
||||
wgNet := netip.MustParsePrefix(network)
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
@@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
@@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer func() {
|
||||
@@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||
manager.decoders = sync.Pool{
|
||||
|
||||
@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a.AsSlice()) {
|
||||
if u.address.Network.Contains(a) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var zeroKey wgtypes.Key
|
||||
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
}
|
||||
@@ -201,14 +203,71 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
func (c *KernelConfigurer) Close() {
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
||||
func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
||||
return nil, fmt.Errorf("wgctl: %w", err)
|
||||
}
|
||||
return WGStats{
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
TxBytes: peer.TransmitBytes,
|
||||
RxBytes: peer.ReceiveBytes,
|
||||
}, nil
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
log.Errorf("Got error while closing wgctl: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||
}
|
||||
fullStats := &Stats{
|
||||
DeviceName: wgDevice.Name,
|
||||
PublicKey: wgDevice.PublicKey.String(),
|
||||
ListenPort: wgDevice.ListenPort,
|
||||
FWMark: wgDevice.FirewallMark,
|
||||
Peers: []Peer{},
|
||||
}
|
||||
|
||||
for _, p := range wgDevice.Peers {
|
||||
peer := Peer{
|
||||
PublicKey: p.PublicKey.String(),
|
||||
AllowedIPs: p.AllowedIPs,
|
||||
TxBytes: p.TransmitBytes,
|
||||
RxBytes: p.ReceiveBytes,
|
||||
LastHandshake: p.LastHandshakeTime,
|
||||
PresharedKey: p.PresharedKey != zeroKey,
|
||||
}
|
||||
if p.Endpoint != nil {
|
||||
peer.Endpoint = *p.Endpoint
|
||||
}
|
||||
fullStats.Peers = append(fullStats.Peers, peer)
|
||||
}
|
||||
return fullStats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wgctl: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
log.Errorf("Got error while closing wgctl: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||
}
|
||||
|
||||
for _, peer := range wgDevice.Peers {
|
||||
stats[peer.PublicKey.String()] = WGStats{
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
TxBytes: peer.TransmitBytes,
|
||||
RxBytes: peer.ReceiveBytes,
|
||||
}
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -17,6 +18,20 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
privateKey = "private_key"
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
allowedIP = "allowed_ip"
|
||||
endpoint = "endpoint"
|
||||
fwmark = "fwmark"
|
||||
listenPort = "listen_port"
|
||||
publicKey = "public_key"
|
||||
presharedKey = "preshared_key"
|
||||
)
|
||||
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
|
||||
type WGUSPConfigurer struct {
|
||||
@@ -178,6 +193,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||
ipcStr, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
||||
}
|
||||
|
||||
return parseStatus(c.deviceName, ipcStr)
|
||||
}
|
||||
|
||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||
func (t *WGUSPConfigurer) startUAPI() {
|
||||
var err error
|
||||
@@ -217,91 +241,75 @@ func (t *WGUSPConfigurer) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
ipc, err := t.device.IpcGet()
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
||||
return nil, fmt.Errorf("ipc get: %w", err)
|
||||
}
|
||||
|
||||
stats, err := findPeerInfo(ipc, peerKey, []string{
|
||||
"last_handshake_time_sec",
|
||||
"last_handshake_time_nsec",
|
||||
"tx_bytes",
|
||||
"rx_bytes",
|
||||
})
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
||||
}
|
||||
|
||||
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||
}
|
||||
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
||||
}
|
||||
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
||||
}
|
||||
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
||||
}
|
||||
|
||||
return WGStats{
|
||||
LastHandshake: time.Unix(sec, nsec),
|
||||
TxBytes: txBytes,
|
||||
RxBytes: rxBytes,
|
||||
}, nil
|
||||
return parseTransfers(ipc)
|
||||
}
|
||||
|
||||
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
|
||||
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
||||
|
||||
lines := strings.Split(ipcInput, "\n")
|
||||
|
||||
configFound := map[string]string{}
|
||||
foundPeer := false
|
||||
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
var (
|
||||
currentKey string
|
||||
currentStats WGStats
|
||||
hasPeer bool
|
||||
)
|
||||
lines := strings.Split(ipc, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// If we're within the details of the found peer and encounter another public key,
|
||||
// this means we're starting another peer's details. So, stop.
|
||||
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
||||
break
|
||||
}
|
||||
|
||||
// Identify the peer with the specific public key
|
||||
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
||||
foundPeer = true
|
||||
}
|
||||
|
||||
for _, key := range searchConfigKeys {
|
||||
if foundPeer && strings.HasPrefix(line, key+"=") {
|
||||
v := strings.SplitN(line, "=", 2)
|
||||
configFound[v[0]] = v[1]
|
||||
if strings.HasPrefix(line, "public_key=") {
|
||||
peerID := strings.TrimPrefix(line, "public_key=")
|
||||
h, err := hex.DecodeString(peerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode peerID: %w", err)
|
||||
}
|
||||
currentKey = base64.StdEncoding.EncodeToString(h)
|
||||
currentStats = WGStats{} // Reset stats for the new peer
|
||||
hasPeer = true
|
||||
stats[currentKey] = currentStats
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasPeer {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.SplitN(line, "=", 2)
|
||||
if len(key) != 2 {
|
||||
continue
|
||||
}
|
||||
switch key[0] {
|
||||
case ipcKeyLastHandshakeTimeSec:
|
||||
hs, err := toLastHandshake(key[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentStats.LastHandshake = hs
|
||||
stats[currentKey] = currentStats
|
||||
case ipcKeyRxBytes:
|
||||
rxBytes, err := toBytes(key[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
||||
}
|
||||
currentStats.RxBytes = rxBytes
|
||||
stats[currentKey] = currentStats
|
||||
case ipcKeyTxBytes:
|
||||
TxBytes, err := toBytes(key[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
||||
}
|
||||
currentStats.TxBytes = TxBytes
|
||||
stats[currentKey] = currentStats
|
||||
}
|
||||
}
|
||||
|
||||
// todo: use multierr
|
||||
for _, key := range searchConfigKeys {
|
||||
if _, ok := configFound[key]; !ok {
|
||||
return configFound, fmt.Errorf("config key not found: %s", key)
|
||||
}
|
||||
}
|
||||
if !foundPeer {
|
||||
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
||||
}
|
||||
|
||||
return configFound, nil
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
@@ -355,9 +363,154 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func toLastHandshake(stringVar string) (time.Time, error) {
|
||||
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||
}
|
||||
return time.Unix(sec, 0), nil
|
||||
}
|
||||
|
||||
func toBytes(s string) (int64, error) {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
||||
// Decode hex string to bytes
|
||||
keyBytes, err := hex.DecodeString(hexKey)
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
||||
}
|
||||
|
||||
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
||||
if len(keyBytes) != 32 {
|
||||
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
||||
}
|
||||
|
||||
// Convert to wgtypes.Key
|
||||
var key wgtypes.Key
|
||||
copy(key[:], keyBytes)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
stats := &Stats{DeviceName: deviceName}
|
||||
var currentPeer *Peer
|
||||
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
key := parts[0]
|
||||
val := parts[1]
|
||||
|
||||
switch key {
|
||||
case privateKey:
|
||||
key, err := hexToWireguardKey(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse private key: %v", err)
|
||||
continue
|
||||
}
|
||||
stats.PublicKey = key.PublicKey().String()
|
||||
case publicKey:
|
||||
// Save previous peer
|
||||
if currentPeer != nil {
|
||||
stats.Peers = append(stats.Peers, *currentPeer)
|
||||
}
|
||||
key, err := hexToWireguardKey(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse public key: %v", err)
|
||||
continue
|
||||
}
|
||||
currentPeer = &Peer{
|
||||
PublicKey: key.String(),
|
||||
}
|
||||
case listenPort:
|
||||
if port, err := strconv.Atoi(val); err == nil {
|
||||
stats.ListenPort = port
|
||||
}
|
||||
case fwmark:
|
||||
if fwmark, err := strconv.Atoi(val); err == nil {
|
||||
stats.FWMark = fwmark
|
||||
}
|
||||
case endpoint:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint: %v", err)
|
||||
continue
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint port: %v", err)
|
||||
continue
|
||||
}
|
||||
currentPeer.Endpoint = net.UDPAddr{
|
||||
IP: net.ParseIP(host),
|
||||
Port: port,
|
||||
}
|
||||
case allowedIP:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(val)
|
||||
if err == nil {
|
||||
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
||||
}
|
||||
case ipcKeyTxBytes:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
rxBytes, err := toBytes(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.TxBytes = rxBytes
|
||||
case ipcKeyRxBytes:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
rxBytes, err := toBytes(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.RxBytes = rxBytes
|
||||
|
||||
case ipcKeyLastHandshakeTimeSec:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ts, err := toLastHandshake(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.LastHandshake = ts
|
||||
case presharedKey:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
if val != "" {
|
||||
currentPeer.PresharedKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentPeer != nil {
|
||||
stats.Peers = append(stats.Peers, *currentPeer)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -2,10 +2,8 @@ package configurer
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -34,58 +32,35 @@ errno=0
|
||||
|
||||
`
|
||||
|
||||
func Test_findPeerInfo(t *testing.T) {
|
||||
func Test_parseTransfers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerKey string
|
||||
searchKeys []string
|
||||
want map[string]string
|
||||
wantErr bool
|
||||
name string
|
||||
peerKey string
|
||||
want WGStats
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
searchKeys: []string{"tx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "38333",
|
||||
name: "single",
|
||||
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
||||
want: WGStats{
|
||||
TxBytes: 0,
|
||||
RxBytes: 0,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "38333",
|
||||
"rx_bytes": "2224",
|
||||
name: "multiple",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
want: WGStats{
|
||||
TxBytes: 38333,
|
||||
RxBytes: 2224,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "lastpeer",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "1212111",
|
||||
"rx_bytes": "1929999999",
|
||||
name: "lastpeer",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
want: WGStats{
|
||||
TxBytes: 1212111,
|
||||
RxBytes: 1929999999,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "peer not found",
|
||||
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
||||
searchKeys: nil,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key not found",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
searchKeys: []string{"tx_bytes", "unknown_key"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "1212111",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
|
||||
key, err := wgtypes.NewKey(res)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
||||
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
||||
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
||||
stats, err := parseTransfers(ipcFixture)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
stat, ok := stats[key.String()]
|
||||
if !ok {
|
||||
require.True(t, ok)
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.want, stat)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
24
client/iface/configurer/wgshow.go
Normal file
24
client/iface/configurer/wgshow.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
PublicKey string
|
||||
Endpoint net.UDPAddr
|
||||
AllowedIPs []net.IPNet
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
LastHandshake time.Time
|
||||
PresharedKey bool
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
DeviceName string
|
||||
PublicKey string
|
||||
ListenPort int
|
||||
FWMark int
|
||||
Peers []Peer
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -24,9 +23,6 @@ type PacketFilter interface {
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
SetNetwork(*net.IPNet)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create nbnetstack tun interface")
|
||||
|
||||
// TODO: get from service listener runtime IP
|
||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("last ip: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
|
||||
@@ -16,5 +16,6 @@ type WGConfigurer interface {
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close()
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
}
|
||||
|
||||
@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
}
|
||||
|
||||
ip := address.IP.String()
|
||||
mask := "0x" + address.Network.Mask.String()
|
||||
|
||||
// Convert prefix length to hex netmask
|
||||
prefixLen := address.Network.Bits()
|
||||
if !address.IP.Is4() {
|
||||
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||
}
|
||||
|
||||
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||
|
||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||
|
||||
|
||||
@@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
||||
}
|
||||
|
||||
w.filter = filter
|
||||
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
||||
|
||||
w.tun.FilteredDevice().SetFilter(filter)
|
||||
return nil
|
||||
@@ -212,9 +211,13 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||
return w.tun.Device()
|
||||
}
|
||||
|
||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return w.configurer.GetStats(peerKey)
|
||||
// GetStats returns the last handshake time, rx and tx bytes
|
||||
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
return w.configurer.GetStats()
|
||||
}
|
||||
|
||||
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||
return w.configurer.FullStats()
|
||||
}
|
||||
|
||||
func (w *WGIface) waitUntilRemoved() error {
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -15,8 +13,8 @@ import (
|
||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||
|
||||
type NetStackTun struct { //nolint:revive
|
||||
address net.IP
|
||||
dnsAddress net.IP
|
||||
address netip.Addr
|
||||
dnsAddress netip.Addr
|
||||
mtu int
|
||||
listenAddress string
|
||||
|
||||
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
|
||||
tundev tun.Device
|
||||
}
|
||||
|
||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
||||
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||
return &NetStackTun{
|
||||
address: address,
|
||||
dnsAddress: dnsAddress,
|
||||
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
addr, ok := netip.AddrFromSlice(t.address)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
||||
}
|
||||
|
||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
||||
}
|
||||
|
||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||
[]netip.Addr{addr.Unmap()},
|
||||
[]netip.Addr{dnsAddr.Unmap()},
|
||||
[]netip.Addr{t.address},
|
||||
[]netip.Addr{t.dnsAddress},
|
||||
t.mtu)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -2,28 +2,27 @@ package wgaddr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
IP net.IP
|
||||
Network *net.IPNet
|
||||
IP netip.Addr
|
||||
Network netip.Prefix
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func ParseWGAddress(address string) (Address, error) {
|
||||
ip, network, err := net.ParseCIDR(address)
|
||||
prefix, err := netip.ParsePrefix(address)
|
||||
if err != nil {
|
||||
return Address{}, err
|
||||
}
|
||||
return Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
IP: prefix.Addr().Unmap(),
|
||||
Network: prefix.Masked(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (addr Address) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
|
||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||
|
||||
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
|
||||
|
||||
Unicode True
|
||||
|
||||
######################################################################
|
||||
@@ -49,6 +51,10 @@ ShowInstDetails Show
|
||||
|
||||
######################################################################
|
||||
|
||||
!include "MUI2.nsh"
|
||||
!include LogicLib.nsh
|
||||
!include "nsDialogs.nsh"
|
||||
|
||||
!define MUI_ICON "${ICON}"
|
||||
!define MUI_UNICON "${ICON}"
|
||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||
@@ -58,9 +64,6 @@ ShowInstDetails Show
|
||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||
######################################################################
|
||||
|
||||
!include "MUI2.nsh"
|
||||
!include LogicLib.nsh
|
||||
|
||||
!define MUI_ABORTWARNING
|
||||
!define MUI_UNABORTWARNING
|
||||
|
||||
@@ -70,13 +73,16 @@ ShowInstDetails Show
|
||||
|
||||
!insertmacro MUI_PAGE_DIRECTORY
|
||||
|
||||
; Custom page for autostart checkbox
|
||||
Page custom AutostartPage AutostartPageLeave
|
||||
|
||||
!insertmacro MUI_PAGE_INSTFILES
|
||||
|
||||
!insertmacro MUI_PAGE_FINISH
|
||||
|
||||
!insertmacro MUI_UNPAGE_WELCOME
|
||||
|
||||
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
||||
|
||||
!insertmacro MUI_UNPAGE_CONFIRM
|
||||
|
||||
!insertmacro MUI_UNPAGE_INSTFILES
|
||||
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
|
||||
Var AutostartCheckbox
|
||||
Var AutostartEnabled
|
||||
|
||||
; Variables for uninstall data deletion option
|
||||
Var DeleteDataCheckbox
|
||||
Var DeleteDataEnabled
|
||||
|
||||
######################################################################
|
||||
|
||||
; Function to create the autostart options page
|
||||
@@ -104,8 +114,8 @@ Function AutostartPage
|
||||
|
||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||
Pop $AutostartCheckbox
|
||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||
${NSD_Check} $AutostartCheckbox
|
||||
StrCpy $AutostartEnabled "1"
|
||||
|
||||
nsDialogs::Show
|
||||
FunctionEnd
|
||||
@@ -115,6 +125,30 @@ Function AutostartPageLeave
|
||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||
FunctionEnd
|
||||
|
||||
; Function to create the uninstall data deletion page
|
||||
Function un.DeleteDataPage
|
||||
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
|
||||
|
||||
nsDialogs::Create 1018
|
||||
Pop $0
|
||||
|
||||
${If} $0 == error
|
||||
Abort
|
||||
${EndIf}
|
||||
|
||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
|
||||
Pop $DeleteDataCheckbox
|
||||
${NSD_Uncheck} $DeleteDataCheckbox
|
||||
StrCpy $DeleteDataEnabled "0"
|
||||
|
||||
nsDialogs::Show
|
||||
FunctionEnd
|
||||
|
||||
; Function to handle leaving the data deletion page
|
||||
Function un.DeleteDataPageLeave
|
||||
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
|
||||
FunctionEnd
|
||||
|
||||
Function GetAppFromCommand
|
||||
Exch $1
|
||||
Push $2
|
||||
@@ -176,10 +210,10 @@ ${EndIf}
|
||||
FunctionEnd
|
||||
######################################################################
|
||||
Section -MainProgram
|
||||
${INSTALL_TYPE}
|
||||
# SetOverwrite ifnewer
|
||||
SetOutPath "$INSTDIR"
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
${INSTALL_TYPE}
|
||||
# SetOverwrite ifnewer
|
||||
SetOutPath "$INSTDIR"
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
SectionEnd
|
||||
######################################################################
|
||||
|
||||
@@ -225,31 +259,58 @@ SectionEnd
|
||||
Section Uninstall
|
||||
${INSTALL_TYPE}
|
||||
|
||||
DetailPrint "Stopping Netbird service..."
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||
DetailPrint "Uninstalling Netbird service..."
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
|
||||
# kill ui client
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
${If} $DeleteDataEnabled == "1"
|
||||
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
|
||||
ClearErrors
|
||||
RMDir /r "${NETBIRD_DATA_DIR}"
|
||||
IfErrors 0 +2 ; If no errors, jump over the message
|
||||
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
|
||||
DetailPrint "Netbird data directory removal complete."
|
||||
${Else}
|
||||
DetailPrint "User did not opt to delete Netbird data."
|
||||
${EndIf}
|
||||
|
||||
# wait the service uninstall take unblock the executable
|
||||
DetailPrint "Waiting for service handle to be released..."
|
||||
Sleep 3000
|
||||
|
||||
DetailPrint "Deleting application files..."
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
DetailPrint "Removing application directory..."
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
DetailPrint "Removing shortcuts..."
|
||||
SetShellVarContext all
|
||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||
|
||||
DetailPrint "Removing registry keys..."
|
||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||
|
||||
DetailPrint "Removing application directory from PATH..."
|
||||
EnVar::SetHKLM
|
||||
EnVar::DeleteValue "path" "$INSTDIR"
|
||||
|
||||
DetailPrint "Uninstallation finished."
|
||||
SectionEnd
|
||||
|
||||
|
||||
|
||||
@@ -58,6 +58,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
if d.firewall == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
total := 0
|
||||
@@ -69,20 +74,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
if d.firewall == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
|
||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||
log.Errorf("failed to set legacy management flag: %v", err)
|
||||
}
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
}
|
||||
@@ -291,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
// return traffic for outbound connections if firewall is stateless
|
||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||
return
|
||||
if fw.IsStateful() {
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||
} else {
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -94,12 +91,13 @@ func TestDefaultManager(t *testing.T) {
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied")
|
||||
return
|
||||
expectedRules := 2
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 1 // only the inbound rule
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
if previousCount != 1 {
|
||||
t.Errorf("old rule was not removed")
|
||||
|
||||
expectedPreviousCount := 0
|
||||
if !fw.IsStateful() {
|
||||
expectedPreviousCount = 1
|
||||
}
|
||||
assert.Equal(t, expectedPreviousCount, previousCount)
|
||||
})
|
||||
|
||||
t.Run("handle default rules", func(t *testing.T) {
|
||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = true
|
||||
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
if len(acl.peerRulesPairs) != 1 {
|
||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
|
||||
expectedRules := 1
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 1 // only inbound allow-all rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultManagerStateless(t *testing.T) {
|
||||
// stateless currently only in userspace, so we have to disable kernel
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
Port: "53",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// In stateless mode, we should have both inbound and outbound rules
|
||||
assert.False(t, fw.IsStateful())
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
if len(rules) != 2 {
|
||||
t.Errorf("rules should contain 2, got: %v", rules)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, 2, len(rules))
|
||||
|
||||
r := rules[0]
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_IN:
|
||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
|
||||
r = rules[1]
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
case r.Direction != mgmProto.RuleDirection_OUT:
|
||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||
return
|
||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
@@ -291,9 +326,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
||||
}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
@@ -336,33 +370,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
if len(acl.peerRulesPairs) != 3 {
|
||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
return
|
||||
expectedRules := 3
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
}
|
||||
|
||||
@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
||||
if runtime.GOOS == "freebsd" {
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
|
||||
@@ -101,7 +101,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
}
|
||||
if !p.providerConfig.DisablePromptLogin {
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
}
|
||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
@@ -7,15 +7,36 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/management/client/common"
|
||||
)
|
||||
|
||||
func TestPromptLogin(t *testing.T) {
|
||||
const (
|
||||
promptLogin = "prompt=login"
|
||||
maxAge0 = "max_age=0"
|
||||
)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
prompt bool
|
||||
name string
|
||||
loginFlag mgm.LoginFlag
|
||||
disablePromptLogin bool
|
||||
expect string
|
||||
}{
|
||||
{"PromptLogin", true},
|
||||
{"NoPromptLogin", false},
|
||||
{
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
expect: promptLogin,
|
||||
},
|
||||
{
|
||||
name: "Max age 0 login",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expect: maxAge0,
|
||||
},
|
||||
{
|
||||
name: "Disable prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
disablePromptLogin: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
@@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||
UseIDToken: true,
|
||||
DisablePromptLogin: !tc.prompt,
|
||||
LoginFlag: tc.loginFlag,
|
||||
}
|
||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||
if err != nil {
|
||||
@@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to request auth info: %v", err)
|
||||
}
|
||||
pattern := "prompt=login"
|
||||
if tc.prompt {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
||||
|
||||
if !tc.disablePromptLogin {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||
} else {
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -68,12 +68,14 @@ type ConfigInput struct {
|
||||
DisableServerRoutes *bool
|
||||
DisableDNS *bool
|
||||
DisableFirewall *bool
|
||||
|
||||
BlockLANAccess *bool
|
||||
BlockLANAccess *bool
|
||||
BlockInbound *bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
DNSLabels domain.List
|
||||
|
||||
LazyConnectionEnabled *bool
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -96,8 +98,8 @@ type Config struct {
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
|
||||
BlockLANAccess bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
@@ -138,6 +140,8 @@ type Config struct {
|
||||
ClientCertKeyPath string
|
||||
|
||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@@ -479,6 +483,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
|
||||
if *input.BlockInbound {
|
||||
log.Infof("blocking inbound connections")
|
||||
} else {
|
||||
log.Infof("allowing inbound connections")
|
||||
}
|
||||
config.BlockInbound = *input.BlockInbound
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||
if *input.DisableNotifications {
|
||||
log.Infof("disabling notifications")
|
||||
@@ -524,6 +538,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||
updated = true
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
303
client/internal/conn_mgr.go
Normal file
303
client/internal/conn_mgr.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
)
|
||||
|
||||
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||
//
|
||||
// The connection manager is responsible for:
|
||||
// - Managing lazy connections via the lazyConnManager
|
||||
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||
// - Handling connection establishment based on peer signaling
|
||||
//
|
||||
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||
type ConnMgr struct {
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
dispatcher *dispatcher.ConnectionDispatcher
|
||||
enabledLocally bool
|
||||
|
||||
lazyConnMgr *manager.Manager
|
||||
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
||||
e := &ConnMgr{
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
dispatcher: dispatcher,
|
||||
}
|
||||
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||
e.enabledLocally = true
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||
func (e *ConnMgr) Start(ctx context.Context) {
|
||||
if e.lazyConnMgr != nil {
|
||||
log.Errorf("lazy connection manager is already started")
|
||||
return
|
||||
}
|
||||
|
||||
if !e.enabledLocally {
|
||||
log.Infof("lazy connection manager is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
}
|
||||
|
||||
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
|
||||
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||
// do not disable lazy connection manager if it was enabled by env var
|
||||
if e.enabledLocally {
|
||||
return nil
|
||||
}
|
||||
|
||||
if enabled {
|
||||
// if the lazy connection manager is already started, do not start it again
|
||||
if e.lazyConnMgr != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("lazy connection manager is enabled by management feature flag")
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
return e.addPeersToLazyConnManager(ctx)
|
||||
} else {
|
||||
if e.lazyConnMgr == nil {
|
||||
return nil
|
||||
}
|
||||
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||
e.closeManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
|
||||
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
|
||||
if e.lazyConnMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
|
||||
|
||||
for peerID := range peerIDs {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||
continue
|
||||
}
|
||||
|
||||
lazyPeerCfg := lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||
PeerConnID: peerConn.ConnID(),
|
||||
Log: peerConn.Log,
|
||||
}
|
||||
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||
}
|
||||
|
||||
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
|
||||
for _, peerID := range added {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
|
||||
continue
|
||||
}
|
||||
|
||||
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
|
||||
if err := peerConn.Open(e.ctx); err != nil {
|
||||
peerConn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
|
||||
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
|
||||
return true
|
||||
}
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !lazyconn.IsSupported(conn.AgentVersionString()) {
|
||||
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
lazyPeerCfg := lazyconn.PeerConfig{
|
||||
PublicKey: peerKey,
|
||||
AllowedIPs: conn.WgConfig().AllowedIps,
|
||||
PeerConnID: conn.ConnID(),
|
||||
Log: conn.Log,
|
||||
}
|
||||
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if excluded {
|
||||
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Infof("peer added to lazy conn manager")
|
||||
return
|
||||
}
|
||||
|
||||
func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||
conn, ok := e.peerStore.Remove(peerKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
}
|
||||
|
||||
e.lazyConnMgr.RemovePeer(peerKey)
|
||||
conn.Log.Infof("removed peer from lazy conn manager")
|
||||
}
|
||||
|
||||
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
||||
conn, ok := e.peerStore.PeerConn(peerKey)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return conn, true
|
||||
}
|
||||
|
||||
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
|
||||
conn.Log.Infof("activated peer from inactive state")
|
||||
if err := conn.Open(e.ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
}
|
||||
return conn, true
|
||||
}
|
||||
|
||||
func (e *ConnMgr) Close() {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
}
|
||||
|
||||
e.ctxCancel()
|
||||
e.wg.Wait()
|
||||
e.lazyConnMgr = nil
|
||||
}
|
||||
|
||||
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
|
||||
cfg := manager.Config{
|
||||
InactivityThreshold: inactivityThresholdEnv(),
|
||||
}
|
||||
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
|
||||
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
e.ctx = ctx
|
||||
e.ctxCancel = cancel
|
||||
|
||||
e.wg.Add(1)
|
||||
go func() {
|
||||
defer e.wg.Done()
|
||||
e.lazyConnMgr.Start(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
|
||||
peers := e.peerStore.PeersPubKey()
|
||||
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
|
||||
for _, peerID := range peers {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||
continue
|
||||
}
|
||||
|
||||
lazyPeerCfg := lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||
PeerConnID: peerConn.ConnID(),
|
||||
Log: peerConn.Log,
|
||||
}
|
||||
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||
}
|
||||
|
||||
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
|
||||
}
|
||||
|
||||
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||
if e.lazyConnMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.ctxCancel()
|
||||
e.wg.Wait()
|
||||
e.lazyConnMgr = nil
|
||||
|
||||
for _, peerID := range e.peerStore.PeersPubKey() {
|
||||
e.peerStore.PeerConnOpen(ctx, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||
return e.lazyConnMgr != nil && e.ctxCancel != nil
|
||||
}
|
||||
|
||||
func inactivityThresholdEnv() *time.Duration {
|
||||
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||
if envValue == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedMinutes, err := strconv.Atoi(envValue)
|
||||
if err != nil || parsedMinutes <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
d := time.Duration(parsedMinutes) * time.Minute
|
||||
return &d
|
||||
}
|
||||
@@ -436,11 +436,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||
DisableDNS: config.DisableDNS,
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
BlockInbound: config.BlockInbound,
|
||||
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -481,7 +483,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
return signalClient, nil
|
||||
}
|
||||
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -269,11 +270,16 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addWgShow(); err != nil {
|
||||
log.Errorf("Failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if g.logFile != "console" {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -365,17 +371,34 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
|
||||
if g.internalConfig.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
|
||||
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
|
||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
|
||||
|
||||
if g.internalConfig.DisableNotifications != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
|
||||
|
||||
if g.internalConfig.ClientCertPath != "" {
|
||||
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath))
|
||||
}
|
||||
if g.internalConfig.ClientCertKeyPath != "" {
|
||||
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addProf() (err error) {
|
||||
@@ -533,6 +556,33 @@ func (g *BundleGenerator) addLogfile() error {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
// add latest rotated log file
|
||||
pattern := filepath.Join(logDir, "client-*.log.gz")
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("failed to glob rotated logs: %v", err)
|
||||
} else if len(files) > 0 {
|
||||
// pick the file with the latest ModTime
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
fi, err := os.Stat(files[i])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
|
||||
return false
|
||||
}
|
||||
fj, err := os.Stat(files[j])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
|
||||
return false
|
||||
}
|
||||
return fi.ModTime().Before(fj.ModTime())
|
||||
})
|
||||
latest := files[len(files)-1]
|
||||
name := filepath.Base(latest)
|
||||
if err := g.addSingleLogFileGz(latest, name); err != nil {
|
||||
log.Warnf("failed to add rotated log %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
stdErrLogPath := filepath.Join(logDir, errorLogFile)
|
||||
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
|
||||
if runtime.GOOS == "darwin" {
|
||||
@@ -563,16 +613,13 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
var logReader io.Reader
|
||||
var logReader io.Reader = logFile
|
||||
if g.anonymize {
|
||||
var writer *io.PipeWriter
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go anonymizeLog(logFile, writer, g.anonymizer)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(logReader, targetName); err != nil {
|
||||
return fmt.Errorf("add %s to zip: %w", targetName, err)
|
||||
}
|
||||
@@ -580,6 +627,44 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addSingleLogFileGz adds a single gzipped log file to the archive
|
||||
func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
|
||||
f, err := os.Open(logPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open gz log file %s: %w", targetName, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create gzip reader: %w", err)
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
var logReader io.Reader = gzr
|
||||
if g.anonymize {
|
||||
var pw *io.PipeWriter
|
||||
logReader, pw = io.Pipe()
|
||||
go anonymizeLog(gzr, pw, g.anonymizer)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
if _, err := io.Copy(gw, logReader); err != nil {
|
||||
return fmt.Errorf("re-gzip: %w", err)
|
||||
}
|
||||
|
||||
if err := gw.Close(); err != nil {
|
||||
return fmt.Errorf("close gzip writer: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(&buf, targetName); err != nil {
|
||||
return fmt.Errorf("add anonymized gz: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
|
||||
header := &zip.FileHeader{
|
||||
Name: filename,
|
||||
|
||||
66
client/internal/debug/wgshow.go
Normal file
66
client/internal/debug/wgshow.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
FullStats() (*configurer.Stats, error)
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addWgShow() error {
|
||||
result, err := g.statusRecorder.PeersStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output := g.toWGShowFormat(result)
|
||||
reader := bytes.NewReader([]byte(output))
|
||||
|
||||
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
|
||||
return fmt.Errorf("add wg show to zip: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
|
||||
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
|
||||
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
|
||||
if s.FWMark != 0 {
|
||||
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
|
||||
}
|
||||
|
||||
for _, peer := range s.Peers {
|
||||
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
|
||||
if peer.Endpoint.IP != nil {
|
||||
if g.anonymize {
|
||||
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
|
||||
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
|
||||
}
|
||||
}
|
||||
if len(peer.AllowedIPs) > 0 {
|
||||
var ipStrings []string
|
||||
for _, ipnet := range peer.AllowedIPs {
|
||||
ipStrings = append(ipStrings, ipnet.String())
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
||||
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
||||
if peer.PresharedKey {
|
||||
sb.WriteString(" preshared key: (hidden)\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -12,13 +12,14 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||
ip, err := netip.ParseAddr(aRecord.RData)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
if !prefix.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||
networkIP := network.Masked().Addr()
|
||||
|
||||
if !networkIP.Is4() {
|
||||
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||
}
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
octetsToUse := (network.Bits() + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||
zoneName, err := generateReverseZoneName(network)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
records := collectPTRRecords(config, network)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
|
||||
@@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||
|
||||
@@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
IP: netip.MustParseAddr("100.66.100.1"),
|
||||
Network: netip.MustParsePrefix("100.66.100.0/24"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
||||
if err != nil {
|
||||
t.Errorf("parse CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
|
||||
@@ -24,11 +24,15 @@ type ServiceViaMemory struct {
|
||||
}
|
||||
|
||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
|
||||
runtimeIP: lastIP.String(),
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
return s
|
||||
@@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
if s.wgInterface.Address().IP.Is6() {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -30,9 +30,12 @@ const (
|
||||
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
|
||||
dnsSecDisabled = "no"
|
||||
)
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
@@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
Family: unix.AF_INET,
|
||||
Address: ipAs4[:],
|
||||
}
|
||||
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
|
||||
if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
|
||||
return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
|
||||
}
|
||||
|
||||
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
|
||||
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
|
||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -23,8 +24,8 @@ type upstreamResolver struct {
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
|
||||
@@ -4,7 +4,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -19,8 +19,8 @@ type upstreamResolver struct {
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -18,16 +19,16 @@ import (
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
lIP netip.Addr
|
||||
lNet netip.Prefix
|
||||
interfaceName string
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
interfaceName string,
|
||||
ip net.IP,
|
||||
net *net.IPNet,
|
||||
ip netip.Addr,
|
||||
net netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
client.DialTimeout = timeout
|
||||
|
||||
upstreamIP := net.ParseIP(upstreamHost)
|
||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||
}
|
||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
@@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
@@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
|
||||
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: ip,
|
||||
IP: ip.AsSlice(),
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
|
||||
@@ -2,7 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
|
||||
@@ -5,7 +5,6 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -18,5 +17,4 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -13,6 +12,5 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
@@ -120,8 +121,10 @@ type EngineConfig struct {
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
BlockLANAccess bool
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -134,6 +137,8 @@ type Engine struct {
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerStore *peerstore.Store
|
||||
|
||||
connMgr *ConnMgr
|
||||
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
|
||||
@@ -170,7 +175,8 @@ type Engine struct {
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
|
||||
statusRecorder *peer.Status
|
||||
statusRecorder *peer.Status
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
|
||||
firewall firewallManager.Manager
|
||||
routeManager routemanager.Manager
|
||||
@@ -235,6 +241,8 @@ func NewEngine(
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
}
|
||||
|
||||
path := statemanager.GetDefaultStatePath()
|
||||
if runtime.GOOS == "ios" {
|
||||
if !fileExists(mobileDep.StateFilePath) {
|
||||
err := createFile(mobileDep.StateFilePath)
|
||||
@@ -244,11 +252,9 @@ func NewEngine(
|
||||
}
|
||||
}
|
||||
|
||||
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
|
||||
}
|
||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||
engine.stateManager = statemanager.New(path)
|
||||
path = mobileDep.StateFilePath
|
||||
}
|
||||
engine.stateManager = statemanager.New(path)
|
||||
|
||||
return engine
|
||||
}
|
||||
@@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
}
|
||||
|
||||
// stopping network monitor first to avoid starting the engine again
|
||||
if e.networkMonitor != nil {
|
||||
e.networkMonitor.Stop()
|
||||
@@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
|
||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
if err := e.removeAllPeers(); err != nil {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
}
|
||||
|
||||
@@ -350,6 +359,7 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("new wg interface: %w", err)
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
e.statusRecorder.SetWgIface(wgIface)
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
@@ -371,7 +381,6 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
e.stateManager.Start()
|
||||
|
||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||
@@ -405,8 +414,7 @@ func (e *Engine) Start() error {
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
err = e.wgInterfaceCreate()
|
||||
if err != nil {
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
@@ -423,7 +431,8 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
// if inbound conns are blocked there is no need to create the ACL manager
|
||||
if e.firewall != nil && !e.config.BlockInbound {
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
}
|
||||
|
||||
@@ -442,6 +451,11 @@ func (e *Engine) Start() error {
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
|
||||
@@ -450,7 +464,6 @@ func (e *Engine) Start() error {
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -475,11 +488,9 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if e.firewall.IsServerRouteSupported() {
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
|
||||
if e.config.BlockLANAccess {
|
||||
@@ -513,6 +524,11 @@ func (e *Engine) initFirewall() error {
|
||||
}
|
||||
|
||||
func (e *Engine) blockLanAccess() {
|
||||
if e.config.BlockInbound {
|
||||
// no need to set up extra deny rules if inbound is already blocked in general
|
||||
return
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// TODO: keep this updated
|
||||
@@ -550,6 +566,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
var modified []*mgmProto.RemotePeerConfig
|
||||
for _, p := range peersUpdate {
|
||||
peerPubKey := p.GetWgPubKey()
|
||||
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPeer.AgentVersionString() != p.AgentVersion {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
|
||||
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
|
||||
if !ok {
|
||||
continue
|
||||
@@ -559,8 +585,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
|
||||
if err != nil {
|
||||
if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
|
||||
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
||||
}
|
||||
}
|
||||
@@ -621,16 +646,11 @@ func (e *Engine) removePeer(peerKey string) error {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
if err != nil {
|
||||
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
||||
}
|
||||
}()
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
conn, exists := e.peerStore.Remove(peerKey)
|
||||
if exists {
|
||||
conn.Close()
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
if err != nil {
|
||||
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -780,56 +800,58 @@ func isNil(server nbssh.Server) bool {
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Warnf("running SSH server is not permitted")
|
||||
log.Info("SSH server is not enabled")
|
||||
return nil
|
||||
} else {
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.sshServer = nil
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.sshServer = nil
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
@@ -952,12 +974,33 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||
log.Errorf("failed to update local IPs: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
|
||||
// This needs to be toggled before applying routes.
|
||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||
log.Errorf("failed to set legacy management flag: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
@@ -965,7 +1008,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
|
||||
if e.acl != nil {
|
||||
@@ -976,7 +1019,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Ingress forward rules
|
||||
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
if err != nil {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
}
|
||||
|
||||
@@ -1022,14 +1066,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
|
||||
e.connMgr.SetExcludeList(excludedLazyPeers)
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
@@ -1065,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
Network: prefix.Masked(),
|
||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||
NetID: route.NetID(protoRoute.NetID),
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
@@ -1099,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
return entries
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
@@ -1155,7 +1194,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
||||
PubKey: offlinePeer.GetWgPubKey(),
|
||||
FQDN: offlinePeer.GetFqdn(),
|
||||
ConnStatus: peer.StatusDisconnected,
|
||||
ConnStatus: peer.StatusIdle,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
@@ -1191,12 +1230,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
peerIPs = append(peerIPs, allowedNetIP)
|
||||
}
|
||||
|
||||
conn, err := e.createPeerConn(peerKey, peerIPs)
|
||||
conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
|
||||
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
||||
conn.Close()
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
@@ -1205,17 +1249,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
conn.Open()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
|
||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
|
||||
log.Debugf("creating peer connection %s", pubKey)
|
||||
|
||||
wgConfig := peer.WgConfig{
|
||||
@@ -1229,11 +1266,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
||||
// randomize connection timeout
|
||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||
config := peer.ConnConfig{
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
AgentVersion: agentVersion,
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassConfig: peer.RosenpassConfig{
|
||||
PubKey: e.getRosenpassPubKey(),
|
||||
Addr: e.getRosenpassAddr(),
|
||||
@@ -1249,7 +1287,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
||||
},
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
PeerConnDispatcher: e.peerConnDispatcher,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1270,7 +1317,7 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
}
|
||||
@@ -1406,6 +1453,7 @@ func (e *Engine) close() {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
}
|
||||
e.wgInterface = nil
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
@@ -1578,13 +1626,39 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes() bool {
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
signalHealthy := e.signal.IsHealthy()
|
||||
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
||||
|
||||
managementHealthy := e.mgmClient.IsHealthy()
|
||||
log.Debugf("management health check: healthy=%t", managementHealthy)
|
||||
|
||||
results := append(e.probeSTUNs(), e.probeTURNs()...)
|
||||
stuns := slices.Clone(e.STUNs)
|
||||
turns := slices.Clone(e.TURNs)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
stats, err := e.wgInterface.GetStats()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get wireguard stats: %v", err)
|
||||
e.syncMsgMux.Unlock()
|
||||
return false
|
||||
}
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
// wgStats could be zero value, in which case we just reset the stats
|
||||
wgStats, ok := stats[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
results := e.probeICE(stuns, turns)
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
relayHealthy := true
|
||||
@@ -1596,37 +1670,16 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
}
|
||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
wgStats, err := e.wgInterface.GetStats(key)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
// wgStats could be zero value, in which case we just reset the stats
|
||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
func (e *Engine) probeSTUNs() []relay.ProbeResult {
|
||||
e.syncMsgMux.Lock()
|
||||
stuns := slices.Clone(e.STUNs)
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
|
||||
}
|
||||
|
||||
func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
e.syncMsgMux.Lock()
|
||||
turns := slices.Clone(e.TURNs)
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
||||
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||
return append(
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)...,
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
@@ -1738,9 +1791,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
}
|
||||
|
||||
// GetWgAddr returns the wireguard address
|
||||
func (e *Engine) GetWgAddr() net.IP {
|
||||
func (e *Engine) GetWgAddr() netip.Addr {
|
||||
if e.wgInterface == nil {
|
||||
return nil
|
||||
return netip.Addr{}
|
||||
}
|
||||
return e.wgInterface.Address().IP
|
||||
}
|
||||
@@ -1750,6 +1803,10 @@ func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
return
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1805,29 +1862,24 @@ func (e *Engine) Address() (netip.Addr, error) {
|
||||
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
addr := e.wgInterface.Address()
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
|
||||
}
|
||||
return ip.Unmap(), nil
|
||||
return e.wgInterface.Address().IP, nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||
if e.firewall == nil {
|
||||
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(rules) == 0 {
|
||||
if e.ingressGatewayMgr == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err := e.ingressGatewayMgr.Close()
|
||||
e.ingressGatewayMgr = nil
|
||||
e.statusRecorder.SetIngressGwMgr(nil)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if e.ingressGatewayMgr == nil {
|
||||
@@ -1878,7 +1930,35 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||
log.Errorf("failed to update forwarding rules: %v", err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return forwardingRules, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
|
||||
excludedPeers := make(map[string]bool)
|
||||
for _, r := range routes {
|
||||
if r.Peer == "" {
|
||||
continue
|
||||
}
|
||||
if !excludedPeers[r.Peer] {
|
||||
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
|
||||
excludedPeers[r.Peer] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
ip := r.TranslatedAddress
|
||||
for _, p := range peers {
|
||||
for _, allowedIP := range p.GetAllowedIps() {
|
||||
if allowedIP != ip.String() {
|
||||
continue
|
||||
}
|
||||
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
|
||||
excludedPeers[p.GetWgPubKey()] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return excludedPeers
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
|
||||
@@ -28,8 +28,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -38,6 +36,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -53,6 +52,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
@@ -93,12 +93,16 @@ type MockWGIface struct {
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetNetFunc func() *netstack.Net
|
||||
}
|
||||
|
||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
@@ -171,8 +175,8 @@ func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||
return m.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
func (m *MockWGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
return m.GetStatsFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
@@ -371,13 +375,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||
@@ -400,6 +404,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
engine.ctx = ctx
|
||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -770,6 +776,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
|
||||
engine.routeManager = mockRouteManager
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
exitErr := engine.Stop()
|
||||
@@ -966,6 +974,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
engine.dnsServer = mockDNSServer
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
exitErr := engine.Stop()
|
||||
@@ -1476,7 +1486,7 @@ func getConnectedPeers(e *Engine) int {
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.Status() == peer.StatusConnected {
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ type wgIfaceBase interface {
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetNet() *netstack.Net
|
||||
FullStats() (*configurer.Stats, error)
|
||||
}
|
||||
|
||||
9
client/internal/lazyconn/activity/listen_ip.go
Normal file
9
client/internal/lazyconn/activity/listen_ip.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package activity
|
||||
|
||||
import "net"
|
||||
|
||||
var (
|
||||
listenIP = net.IP{127, 0, 0, 1}
|
||||
)
|
||||
10
client/internal/lazyconn/activity/listen_ip_linux.go
Normal file
10
client/internal/lazyconn/activity/listen_ip_linux.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !android
|
||||
|
||||
package activity
|
||||
|
||||
import "net"
|
||||
|
||||
var (
|
||||
// use this ip to avoid eBPF proxy congestion
|
||||
listenIP = net.IP{127, 0, 1, 1}
|
||||
)
|
||||
106
client/internal/lazyconn/activity/listener.go
Normal file
106
client/internal/lazyconn/activity/listener.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||
type Listener struct {
|
||||
wgIface lazyconn.WGIface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
conn *net.UDPConn
|
||||
endpoint *net.UDPAddr
|
||||
done sync.Mutex
|
||||
|
||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||
}
|
||||
|
||||
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
d := &Listener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
}
|
||||
|
||||
conn, err := d.newConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
|
||||
}
|
||||
d.conn = conn
|
||||
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
if err := d.createEndpoint(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.done.Lock()
|
||||
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Listener) ReadPackets() {
|
||||
for {
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||
if err != nil {
|
||||
if d.isClosed.Load() {
|
||||
d.peerCfg.Log.Debugf("exit from activity listener")
|
||||
} else {
|
||||
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if n < 1 {
|
||||
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err := d.removeEndpoint(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
|
||||
d.done.Unlock()
|
||||
}
|
||||
|
||||
func (d *Listener) Close() {
|
||||
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
|
||||
d.isClosed.Store(true)
|
||||
|
||||
if err := d.conn.Close(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to close UDP listener: %s", err)
|
||||
}
|
||||
d.done.Lock()
|
||||
}
|
||||
|
||||
func (d *Listener) removeEndpoint() error {
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (d *Listener) createEndpoint() error {
|
||||
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
|
||||
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
|
||||
}
|
||||
|
||||
func (d *Listener) newConn() (*net.UDPConn, error) {
|
||||
addr := &net.UDPAddr{
|
||||
Port: 0,
|
||||
IP: listenIP,
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create activity listener on %s: %s", addr, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
41
client/internal/lazyconn/activity/listener_test.go
Normal file
41
client/internal/lazyconn/activity/listener_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestNewListener(t *testing.T) {
|
||||
peer := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
l, err := NewListener(MocWGIface{}, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
chanClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(chanClosed)
|
||||
l.ReadPackets()
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
l.Close()
|
||||
|
||||
select {
|
||||
case <-chanClosed:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
95
client/internal/lazyconn/activity/manager.go
Normal file
95
client/internal/lazyconn/activity/manager.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
OnActivityChan chan peerid.ConnID
|
||||
|
||||
wgIface lazyconn.WGIface
|
||||
|
||||
peers map[peerid.ConnID]*Listener
|
||||
done chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(wgIface lazyconn.WGIface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
wgIface: wgIface,
|
||||
peers: make(map[peerid.ConnID]*Listener),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.peers[peerCfg.PeerConnID]; ok {
|
||||
log.Warnf("activity listener already exists for: %s", peerCfg.PublicKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
listener, err := NewListener(m.wgIface, peerCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.peers[peerCfg.PeerConnID] = listener
|
||||
|
||||
go m.waitForTraffic(listener, peerCfg.PeerConnID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
listener, ok := m.peers[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Debugf("removing activity listener")
|
||||
delete(m.peers, peerConnID)
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
func (m *Manager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
close(m.done)
|
||||
for peerID, listener := range m.peers {
|
||||
delete(m.peers, peerID)
|
||||
listener.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
|
||||
listener.ReadPackets()
|
||||
|
||||
m.mu.Lock()
|
||||
if _, ok := m.peers[peerConnID]; !ok {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(m.peers, peerConnID)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.notify(peerConnID)
|
||||
}
|
||||
|
||||
func (m *Manager) notify(peerConnID peerid.ConnID) {
|
||||
select {
|
||||
case <-m.done:
|
||||
case m.OnActivityChan <- peerConnID:
|
||||
}
|
||||
}
|
||||
162
client/internal/lazyconn/activity/manager_test.go
Normal file
162
client/internal/lazyconn/activity/manager_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
PeerID string
|
||||
}
|
||||
|
||||
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
return peerid.ConnID(m)
|
||||
}
|
||||
|
||||
type MocWGIface struct {
|
||||
}
|
||||
|
||||
func (m MocWGIface) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
mocWgInterface := &MocWGIface{}
|
||||
|
||||
peer1 := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
mgr := NewManager(mocWgInterface)
|
||||
defer mgr.Close()
|
||||
peerCfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
if peerConnID != peerCfg1.PeerConnID {
|
||||
t.Fatalf("unexpected peerConnID: %v", peerConnID)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_RemovePeerActivity(t *testing.T) {
|
||||
mocWgInterface := &MocWGIface{}
|
||||
|
||||
peer1 := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
mgr := NewManager(mocWgInterface)
|
||||
defer mgr.Close()
|
||||
|
||||
peerCfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
|
||||
|
||||
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
||||
|
||||
if err := trigger(addr); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-mgr.OnActivityChan:
|
||||
t.Fatal("should not have active activity")
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
mocWgInterface := &MocWGIface{}
|
||||
|
||||
peer1 := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
mgr := NewManager(mocWgInterface)
|
||||
defer mgr.Close()
|
||||
|
||||
peerCfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
peer2 := &MocPeer{}
|
||||
peerCfg2 := lazyconn.PeerConfig{
|
||||
PublicKey: peer2.PeerID,
|
||||
PeerConnID: peer2.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey2"),
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg2); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-mgr.OnActivityChan:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timed out waiting for activity")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func trigger(addr string) error {
|
||||
// Create a connection to the destination UDP address and port
|
||||
conn, err := net.Dial("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write the bytes to the UDP connection
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
32
client/internal/lazyconn/doc.go
Normal file
32
client/internal/lazyconn/doc.go
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
Package lazyconn provides mechanisms for managing lazy connections, which activate on demand to optimize resource usage and establish connections efficiently.
|
||||
|
||||
## Overview
|
||||
|
||||
The package includes a `Manager` component responsible for:
|
||||
- Managing lazy connections activated on-demand
|
||||
- Managing inactivity monitors for lazy connections (based on peer disconnection events)
|
||||
- Maintaining a list of excluded peers that should always have permanent connections
|
||||
- Handling remote peer connection initiatives based on peer signaling
|
||||
|
||||
## Thread-Safe Operations
|
||||
|
||||
The `Manager` ensures thread safety across multiple operations, categorized by caller:
|
||||
|
||||
- **Engine (single goroutine)**:
|
||||
- `AddPeer`: Adds a peer to the connection manager.
|
||||
- `RemovePeer`: Removes a peer from the connection manager.
|
||||
- `ActivatePeer`: Activates a lazy connection for a peer. This come from Signal client
|
||||
- `ExcludePeer`: Marks peers for a permanent connection. Like router peers and other peers that should always have a connection.
|
||||
|
||||
- **Connection Dispatcher (any peer routine)**:
|
||||
- `onPeerConnected`: Suspend the inactivity monitor for an active peer connection.
|
||||
- `onPeerDisconnected`: Starts the inactivity monitor for a disconnected peer.
|
||||
|
||||
- **Activity Manager**:
|
||||
- `onPeerActivity`: Run peer.Open(context).
|
||||
|
||||
- **Inactivity Monitor**:
|
||||
- `onPeerInactivityTimedOut`: Close peer connection and restart activity monitor.
|
||||
*/
|
||||
package lazyconn
|
||||
26
client/internal/lazyconn/env.go
Normal file
26
client/internal/lazyconn/env.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
|
||||
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
|
||||
)
|
||||
|
||||
func IsLazyConnEnabledByEnv() bool {
|
||||
val := os.Getenv(EnvEnableLazyConn)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
enabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
peer "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
|
||||
MinimumInactivityThreshold = 3 * time.Minute
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
id peer.ConnID
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
inactivityThreshold time.Duration
|
||||
}
|
||||
|
||||
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
|
||||
i := &Monitor{
|
||||
id: peerID,
|
||||
timer: time.NewTimer(0),
|
||||
inactivityThreshold: threshold,
|
||||
}
|
||||
i.timer.Stop()
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
defer i.timer.Stop()
|
||||
|
||||
ctx, i.cancel = context.WithCancel(ctx)
|
||||
defer func() {
|
||||
defer i.cancel()
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
select {
|
||||
case timeoutChan <- i.id:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Monitor) Stop() {
|
||||
if i.cancel == nil {
|
||||
return
|
||||
}
|
||||
i.cancel()
|
||||
}
|
||||
|
||||
func (i *Monitor) PauseTimer() {
|
||||
i.timer.Stop()
|
||||
}
|
||||
|
||||
func (i *Monitor) ResetTimer() {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
}
|
||||
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
}
|
||||
|
||||
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
return peerid.ConnID(m)
|
||||
}
|
||||
|
||||
func TestInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReuseInactivityMonitor(t *testing.T) {
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
for i := 2; i > 0; i-- {
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(testTimeoutCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
testTimeoutCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
im.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPauseInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
trashHold := time.Second * 3
|
||||
im := NewInactivityMonitor(p.ConnID(), trashHold)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(ctx, timeoutChan)
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second) // grant time to start the monitor
|
||||
im.PauseTimer()
|
||||
|
||||
// check to do not receive timeout
|
||||
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
|
||||
defer thresholdCancel()
|
||||
select {
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-thresholdCtx.Done():
|
||||
// test ok
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
}
|
||||
|
||||
// test reset timer
|
||||
im.ResetTimer()
|
||||
|
||||
select {
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
// expected timeout
|
||||
}
|
||||
}
|
||||
404
client/internal/lazyconn/manager/manager.go
Normal file
404
client/internal/lazyconn/manager/manager.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
)
|
||||
|
||||
const (
|
||||
watcherActivity watcherType = iota
|
||||
watcherInactivity
|
||||
)
|
||||
|
||||
type watcherType int
|
||||
|
||||
type managedPeer struct {
|
||||
peerCfg *lazyconn.PeerConfig
|
||||
expectedWatcher watcherType
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
InactivityThreshold *time.Duration
|
||||
}
|
||||
|
||||
// Manager manages lazy connections
|
||||
// It is responsible for:
|
||||
// - Managing lazy connections activated on-demand
|
||||
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
|
||||
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||
// - Handling connection establishment based on peer signaling
|
||||
type Manager struct {
|
||||
peerStore *peerstore.Store
|
||||
connStateDispatcher *dispatcher.ConnectionDispatcher
|
||||
inactivityThreshold time.Duration
|
||||
|
||||
connStateListener *dispatcher.ConnectionListener
|
||||
managedPeers map[string]*lazyconn.PeerConfig
|
||||
managedPeersByConnID map[peerid.ConnID]*managedPeer
|
||||
excludes map[string]lazyconn.PeerConfig
|
||||
managedPeersMu sync.Mutex
|
||||
|
||||
activityManager *activity.Manager
|
||||
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
|
||||
|
||||
cancel context.CancelFunc
|
||||
onInactive chan peerid.ConnID
|
||||
}
|
||||
|
||||
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
|
||||
log.Infof("setup lazy connection service")
|
||||
m := &Manager{
|
||||
peerStore: peerStore,
|
||||
connStateDispatcher: connStateDispatcher,
|
||||
inactivityThreshold: inactivity.DefaultInactivityThreshold,
|
||||
managedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
|
||||
excludes: make(map[string]lazyconn.PeerConfig),
|
||||
activityManager: activity.NewManager(wgIface),
|
||||
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
|
||||
onInactive: make(chan peerid.ConnID),
|
||||
}
|
||||
|
||||
if config.InactivityThreshold != nil {
|
||||
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
|
||||
m.inactivityThreshold = *config.InactivityThreshold
|
||||
} else {
|
||||
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
m.connStateListener = &dispatcher.ConnectionListener{
|
||||
OnConnected: m.onPeerConnected,
|
||||
OnDisconnected: m.onPeerDisconnected,
|
||||
}
|
||||
|
||||
connStateDispatcher.AddListener(m.connStateListener)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Start starts the manager and listens for peer activity and inactivity events
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
defer m.close()
|
||||
|
||||
ctx, m.cancel = context.WithCancel(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(ctx, peerConnID)
|
||||
case peerConnID := <-m.onInactive:
|
||||
m.onPeerInactivityTimedOut(peerConnID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExcludePeer marks peers for a permanent connection
|
||||
// It removes peers from the managed list if they are added to the exclude list
|
||||
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
|
||||
// this case, we suppose that the connection status is connected or connecting.
|
||||
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
|
||||
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
added := make([]string, 0)
|
||||
excludes := make(map[string]lazyconn.PeerConfig, len(peerConfigs))
|
||||
|
||||
for _, peerCfg := range peerConfigs {
|
||||
log.Infof("update excluded lazy connection list with peer: %s", peerCfg.PublicKey)
|
||||
excludes[peerCfg.PublicKey] = peerCfg
|
||||
}
|
||||
|
||||
// if a peer is newly added to the exclude list, remove from the managed peers list
|
||||
for pubKey, peerCfg := range excludes {
|
||||
if _, wasExcluded := m.excludes[pubKey]; wasExcluded {
|
||||
continue
|
||||
}
|
||||
|
||||
added = append(added, pubKey)
|
||||
peerCfg.Log.Infof("peer newly added to lazy connection exclude list")
|
||||
m.removePeer(pubKey)
|
||||
}
|
||||
|
||||
// if a peer has been removed from exclude list then it should be added to the managed peers
|
||||
for pubKey, peerCfg := range m.excludes {
|
||||
if _, stillExcluded := excludes[pubKey]; stillExcluded {
|
||||
continue
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
|
||||
|
||||
if err := m.addActivePeer(ctx, peerCfg); err != nil {
|
||||
log.Errorf("failed to add peer to lazy connection manager: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
m.excludes = excludes
|
||||
return added
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
peerCfg.Log.Debugf("adding peer to lazy connection manager")
|
||||
|
||||
_, exists := m.excludes[peerCfg.PublicKey]
|
||||
if exists {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||
peerCfg.Log.Warnf("peer already managed")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(peerCfg); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherActivity,
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// AddActivePeers adds a list of peers to the lazy connection manager
|
||||
// suppose these peers was in connected or in connecting states
|
||||
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
for _, cfg := range peerCfg {
|
||||
if _, ok := m.managedPeers[cfg.PublicKey]; ok {
|
||||
cfg.Log.Errorf("peer already managed")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.addActivePeer(ctx, cfg); err != nil {
|
||||
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(peerID string) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
m.removePeer(peerID)
|
||||
}
|
||||
|
||||
// ActivatePeer activates a peer connection when a signal message is received
|
||||
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
cfg, ok := m.managedPeers[peerID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// signal messages coming continuously after success activation, with this avoid the multiple activation
|
||||
if mp.expectedWatcher == watcherInactivity {
|
||||
return false
|
||||
}
|
||||
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
|
||||
im, ok := m.inactivityMonitors[cfg.PeerConnID]
|
||||
if !ok {
|
||||
cfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
return false
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("starting inactivity monitor")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
|
||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||
peerCfg.Log.Warnf("peer already managed")
|
||||
return nil
|
||||
}
|
||||
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherInactivity,
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) removePeer(peerID string) {
|
||||
cfg, ok := m.managedPeers[peerID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Log.Infof("removing lazy peer")
|
||||
|
||||
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
|
||||
im.Stop()
|
||||
delete(m.inactivityMonitors, cfg.PeerConnID)
|
||||
cfg.Log.Debugf("inactivity monitor stopped")
|
||||
}
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
delete(m.managedPeers, peerID)
|
||||
delete(m.managedPeersByConnID, cfg.PeerConnID)
|
||||
}
|
||||
|
||||
func (m *Manager) close() {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
m.cancel()
|
||||
|
||||
m.connStateDispatcher.RemoveListener(m.connStateListener)
|
||||
m.activityManager.Close()
|
||||
for _, iw := range m.inactivityMonitors {
|
||||
iw.Stop()
|
||||
}
|
||||
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
|
||||
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
|
||||
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
|
||||
log.Infof("lazy connection manager closed")
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by conn id: %v", peerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherActivity {
|
||||
mp.peerCfg.Log.Warnf("ignore activity event")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("detected peer activity")
|
||||
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
|
||||
mp.peerCfg.Log.Infof("starting inactivity monitor")
|
||||
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
|
||||
|
||||
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by id: %v", peerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("connection timed out")
|
||||
|
||||
// this is blocking operation, potentially can be optimized
|
||||
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
// just in case free up
|
||||
m.inactivityMonitors[peerConnID].PauseTimer()
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
|
||||
iw.PauseTimer()
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
|
||||
iw.ResetTimer()
|
||||
}
|
||||
16
client/internal/lazyconn/peercfg.go
Normal file
16
client/internal/lazyconn/peercfg.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type PeerConfig struct {
|
||||
PublicKey string
|
||||
AllowedIPs []netip.Prefix
|
||||
PeerConnID id.ConnID
|
||||
Log *log.Entry
|
||||
}
|
||||
41
client/internal/lazyconn/support.go
Normal file
41
client/internal/lazyconn/support.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
var (
|
||||
minVersion = version.Must(version.NewVersion("0.45.0"))
|
||||
)
|
||||
|
||||
func IsSupported(agentVersion string) bool {
|
||||
if agentVersion == "development" {
|
||||
return true
|
||||
}
|
||||
|
||||
// filter out versions like this: a6c5960, a7d5c522, d47be154
|
||||
if !strings.Contains(agentVersion, ".") {
|
||||
return false
|
||||
}
|
||||
|
||||
normalizedVersion := normalizeVersion(agentVersion)
|
||||
inputVer, err := version.NewVersion(normalizedVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return inputVer.GreaterThanOrEqual(minVersion)
|
||||
}
|
||||
|
||||
func normalizeVersion(version string) string {
|
||||
// Remove prefixes like 'v' or 'a'
|
||||
if len(version) > 0 && (version[0] == 'v' || version[0] == 'a') {
|
||||
version = version[1:]
|
||||
}
|
||||
|
||||
// Remove any suffixes like '-dirty', '-dev', '-SNAPSHOT', etc.
|
||||
parts := strings.Split(version, "-")
|
||||
return parts[0]
|
||||
}
|
||||
31
client/internal/lazyconn/support_test.go
Normal file
31
client/internal/lazyconn/support_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package lazyconn
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
version string
|
||||
want bool
|
||||
}{
|
||||
{"development", true},
|
||||
{"0.45.0", true},
|
||||
{"v0.45.0", true},
|
||||
{"0.45.1", true},
|
||||
{"0.45.1-SNAPSHOT-559e6731", true},
|
||||
{"v0.45.1-dev", true},
|
||||
{"a7d5c522", false},
|
||||
{"0.9.6", false},
|
||||
{"0.9.6-SNAPSHOT", false},
|
||||
{"0.9.6-SNAPSHOT-2033650", false},
|
||||
{"meta_wt_version", false},
|
||||
{"v0.31.1-dev", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.version, func(t *testing.T) {
|
||||
if got := IsSupported(tt.version); got != tt.want {
|
||||
t.Errorf("IsSupported() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
14
client/internal/lazyconn/wgiface.go
Normal file
14
client/internal/lazyconn/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
|
||||
|
||||
// fallback if mark rules are not in place
|
||||
wgnet := c.iface.Address().Network
|
||||
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
|
||||
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
|
||||
}
|
||||
|
||||
// mapRxPackets maps packet counts to RX based on flow direction
|
||||
@@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
|
||||
// fallback if marks are not set
|
||||
wgaddr := c.iface.Address().IP
|
||||
wgnetwork := c.iface.Address().Network
|
||||
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||
|
||||
switch {
|
||||
case wgaddr.Equal(src):
|
||||
case wgaddr == srcIP:
|
||||
return nftypes.Egress
|
||||
case wgaddr.Equal(dst):
|
||||
case wgaddr == dstIP:
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(src):
|
||||
case wgnetwork.Contains(srcIP):
|
||||
// netbird network -> resource network
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(dst):
|
||||
case wgnetwork.Contains(dstIP):
|
||||
// resource network -> netbird network
|
||||
return nftypes.Egress
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -23,17 +23,16 @@ type Logger struct {
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceIPNet net.IPNet
|
||||
wgIfaceNet netip.Prefix
|
||||
dnsCollection atomic.Bool
|
||||
exitNodeCollection atomic.Bool
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
|
||||
return &Logger{
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceIPNet: wgIfaceIPNet,
|
||||
wgIfaceNet: wgIfaceIPNet,
|
||||
Store: store.NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
@@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
|
||||
var isSrcExitNode bool
|
||||
var isDestExitNode bool
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||
if !l.wgIfaceNet.Contains(event.SourceIP) {
|
||||
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||
if !l.wgIfaceNet.Contains(event.DestIP) {
|
||||
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(nil, net.IPNet{})
|
||||
logger := logger.New(nil, netip.Prefix{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -34,11 +34,11 @@ type Manager struct {
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var ipNet net.IPNet
|
||||
var prefix netip.Prefix
|
||||
if iface != nil {
|
||||
ipNet = *iface.Address().Network
|
||||
prefix = iface.Address().Network
|
||||
}
|
||||
flowLogger := logger.New(statusRecorder, ipNet)
|
||||
flowLogger := logger.New(statusRecorder, prefix)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
@@ -123,8 +123,14 @@ func (m *Manager) disableFlow() error {
|
||||
|
||||
m.logger.Close()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
if m.receiverClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := m.receiverClient.Close()
|
||||
m.receiverClient = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("close: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package netflow
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||
func TestManager_Update(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
@@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
|
||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
@@ -17,8 +17,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -26,32 +30,20 @@ import (
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
type ConnPriority int
|
||||
|
||||
func (cp ConnPriority) String() string {
|
||||
switch cp {
|
||||
case connPriorityNone:
|
||||
return "None"
|
||||
case connPriorityRelay:
|
||||
return "PriorityRelay"
|
||||
case connPriorityICETurn:
|
||||
return "PriorityICETurn"
|
||||
case connPriorityICEP2P:
|
||||
return "PriorityICEP2P"
|
||||
default:
|
||||
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
|
||||
connPriorityNone ConnPriority = 0
|
||||
connPriorityRelay ConnPriority = 1
|
||||
connPriorityICETurn ConnPriority = 2
|
||||
connPriorityICEP2P ConnPriority = 3
|
||||
)
|
||||
|
||||
type ServiceDependencies struct {
|
||||
StatusRecorder *Status
|
||||
Signaler *Signaler
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
Semaphore *semaphoregroup.SemaphoreGroup
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
}
|
||||
|
||||
type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
@@ -76,6 +68,8 @@ type ConnConfig struct {
|
||||
// LocalKey is a public key of a local peer
|
||||
LocalKey string
|
||||
|
||||
AgentVersion string
|
||||
|
||||
Timeout time.Duration
|
||||
|
||||
WgConfig WgConfig
|
||||
@@ -89,22 +83,23 @@ type ConnConfig struct {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
log *log.Entry
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
handshaker *Handshaker
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
|
||||
statusRelay *AtomicConnStatus
|
||||
statusICE *AtomicConnStatus
|
||||
currentConnPriority ConnPriority
|
||||
statusRelay *worker.AtomicWorkerStatus
|
||||
statusICE *worker.AtomicWorkerStatus
|
||||
currentConnPriority conntype.ConnPriority
|
||||
opened bool // this flag is used to prevent close in case of not opened connection
|
||||
|
||||
workerICE *WorkerICE
|
||||
@@ -120,9 +115,12 @@ type Conn struct {
|
||||
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
handshaker *Handshaker
|
||||
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
wg sync.WaitGroup
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
@@ -130,91 +128,101 @@ type Conn struct {
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
|
||||
func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
if len(config.WgConfig.AllowedIps) == 0 {
|
||||
return nil, fmt.Errorf("allowed IPs is empty")
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(engineCtx)
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
log: connLog,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
semaphore: semaphore,
|
||||
dumpState: newStateDump(config.Key, connLog, statusRecorder),
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
peerConnDispatcher: services.PeerConnDispatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
}
|
||||
|
||||
ctrl := isController(config)
|
||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
|
||||
|
||||
go conn.handshaker.Listen()
|
||||
|
||||
go conn.dumpState.Start(ctx)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Open opens connection to the remote peer
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open() {
|
||||
conn.semaphore.Add(conn.ctx)
|
||||
conn.log.Debugf("open connection to peer")
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.semaphore.Add(engineCtx)
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
conn.opened = true
|
||||
|
||||
if conn.opened {
|
||||
conn.semaphore.Done(engineCtx)
|
||||
return nil
|
||||
}
|
||||
|
||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
defer conn.wg.Done()
|
||||
conn.handshaker.Listen(conn.ctx)
|
||||
}()
|
||||
go conn.dumpState.Start(conn.ctx)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: StatusDisconnected,
|
||||
ConnStatus: StatusConnecting,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("error while updating the state err: %v", err)
|
||||
if err := conn.statusRecorder.UpdatePeerState(peerState); err != nil {
|
||||
conn.Log.Warnf("error while updating the state err: %v", err)
|
||||
}
|
||||
|
||||
go conn.startHandshakeAndReconnect(conn.ctx)
|
||||
}
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
defer conn.wg.Done()
|
||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||
conn.semaphore.Done(conn.ctx)
|
||||
|
||||
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||
defer conn.semaphore.Done(conn.ctx)
|
||||
conn.waitInitialRandomSleepTime(ctx)
|
||||
conn.dumpState.SendOffer()
|
||||
if err := conn.handshaker.sendOffer(); err != nil {
|
||||
conn.Log.Errorf("failed to send initial offer: %v", err)
|
||||
}
|
||||
|
||||
conn.dumpState.SendOffer()
|
||||
err := conn.handshaker.sendOffer()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to send initial offer: %v", err)
|
||||
}
|
||||
|
||||
go conn.guard.Start(ctx)
|
||||
go conn.listenGuardEvent(ctx)
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
conn.wg.Done()
|
||||
}()
|
||||
}()
|
||||
conn.opened = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
@@ -223,14 +231,14 @@ func (conn *Conn) Close() {
|
||||
defer conn.wgWatcherWg.Wait()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
conn.log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
if !conn.opened {
|
||||
conn.log.Debugf("ignore close connection to peer")
|
||||
conn.Log.Debugf("ignore close connection to peer")
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
@@ -238,7 +246,7 @@ func (conn *Conn) Close() {
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to close wg proxy for relay: %v", err)
|
||||
conn.Log.Errorf("failed to close wg proxy for relay: %v", err)
|
||||
}
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
@@ -246,13 +254,13 @@ func (conn *Conn) Close() {
|
||||
if conn.wgProxyICE != nil {
|
||||
err := conn.wgProxyICE.CloseConn()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to close wg proxy for ice: %v", err)
|
||||
conn.Log.Errorf("failed to close wg proxy for ice: %v", err)
|
||||
}
|
||||
conn.wgProxyICE = nil
|
||||
}
|
||||
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
conn.freeUpConnID()
|
||||
@@ -262,14 +270,16 @@ func (conn *Conn) Close() {
|
||||
}
|
||||
|
||||
conn.setStatusToDisconnected()
|
||||
conn.log.Infof("peer connection has been closed")
|
||||
conn.opened = false
|
||||
conn.wg.Wait()
|
||||
conn.Log.Infof("peer connection closed")
|
||||
}
|
||||
|
||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
conn.dumpState.RemoteAnswer()
|
||||
conn.log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
|
||||
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
|
||||
return conn.handshaker.OnRemoteAnswer(answer)
|
||||
}
|
||||
|
||||
@@ -298,7 +308,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
|
||||
|
||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
||||
conn.dumpState.RemoteOffer()
|
||||
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||
return conn.handshaker.OnRemoteOffer(offer)
|
||||
}
|
||||
|
||||
@@ -307,19 +317,24 @@ func (conn *Conn) WgConfig() WgConfig {
|
||||
return conn.config.WgConfig
|
||||
}
|
||||
|
||||
// Status returns current status of the Conn
|
||||
func (conn *Conn) Status() ConnStatus {
|
||||
// IsConnected unit tests only
|
||||
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
|
||||
func (conn *Conn) IsConnected() bool {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
return conn.evalStatus()
|
||||
return conn.currentConnPriority != conntype.None
|
||||
}
|
||||
|
||||
func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
||||
func (conn *Conn) ConnID() id.ConnID {
|
||||
return id.ConnID(conn)
|
||||
}
|
||||
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
||||
func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConnInfo ICEConnInfo) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -327,21 +342,21 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
return
|
||||
}
|
||||
|
||||
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
||||
conn.log.Errorf("remote ICE connection is nil")
|
||||
if remoteConnNil(conn.Log, iceConnInfo.RemoteConn) {
|
||||
conn.Log.Errorf("remote ICE connection is nil")
|
||||
return
|
||||
}
|
||||
|
||||
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
|
||||
// todo consider to remove this check
|
||||
if conn.currentConnPriority > priority {
|
||||
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Infof("set ICE to active connection")
|
||||
conn.Log.Infof("set ICE to active connection")
|
||||
conn.dumpState.P2PConnected()
|
||||
|
||||
var (
|
||||
@@ -353,7 +368,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
conn.dumpState.NewLocalProxy()
|
||||
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
ep = wgProxy.EndpointAddr()
|
||||
@@ -369,7 +384,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
@@ -388,10 +403,16 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
oldState := conn.currentConnPriority
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
|
||||
if oldState == conntype.None {
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
@@ -402,22 +423,22 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Tracef("ICE connection state changed to disconnected")
|
||||
conn.Log.Tracef("ICE connection state changed to disconnected")
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// switch back to relay connection
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.dumpState.SwitchToRelay()
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
@@ -425,17 +446,18 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
||||
conn.currentConnPriority = connPriorityNone
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != StatusDisconnected
|
||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||
if changed {
|
||||
conn.guard.SetICEConnDisconnected()
|
||||
}
|
||||
conn.statusICE.Set(StatusDisconnected)
|
||||
conn.statusICE.SetDisconnected()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -446,7 +468,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -456,41 +478,41 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
if err := rci.relayedConn.Close(); err != nil {
|
||||
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
conn.Log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.dumpState.RelayConnected()
|
||||
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
|
||||
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
conn.dumpState.NewLocalProxy()
|
||||
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
|
||||
if conn.isICEActive() {
|
||||
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -502,12 +524,13 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.Log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
|
||||
func (conn *Conn) onRelayDisconnected() {
|
||||
@@ -518,14 +541,15 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Infof("relay connection is disconnected")
|
||||
conn.Log.Debugf("relay connection is disconnected")
|
||||
|
||||
if conn.currentConnPriority == connPriorityRelay {
|
||||
conn.log.Infof("clean up WireGuard config")
|
||||
if conn.currentConnPriority == conntype.Relay {
|
||||
conn.Log.Debugf("clean up WireGuard config")
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
conn.currentConnPriority = connPriorityNone
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -533,11 +557,11 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
|
||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||
changed := conn.statusRelay.Get() != worker.StatusDisconnected
|
||||
if changed {
|
||||
conn.guard.SetRelayedConnDisconnected()
|
||||
}
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
conn.statusRelay.SetDisconnected()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -546,22 +570,15 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||
conn.Log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) listenGuardEvent(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-conn.guard.Reconnect:
|
||||
conn.log.Infof("send offer to peer")
|
||||
conn.dumpState.SendOffer()
|
||||
if err := conn.handshaker.SendOffer(); err != nil {
|
||||
conn.log.Errorf("failed to send offer: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
func (conn *Conn) onGuardEvent() {
|
||||
conn.Log.Debugf("send offer to peer")
|
||||
conn.dumpState.SendOffer()
|
||||
if err := conn.handshaker.SendOffer(); err != nil {
|
||||
conn.Log.Errorf("failed to send offer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -588,7 +605,7 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("unable to save peer's Relay state, got error: %v", err)
|
||||
conn.Log.Warnf("unable to save peer's Relay state, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -607,17 +624,18 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerICEState(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("unable to save peer's ICE state, got error: %v", err)
|
||||
conn.Log.Warnf("unable to save peer's ICE state, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setStatusToDisconnected() {
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
conn.statusICE.Set(StatusDisconnected)
|
||||
conn.statusRelay.SetDisconnected()
|
||||
conn.statusICE.SetDisconnected()
|
||||
conn.currentConnPriority = conntype.None
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: StatusDisconnected,
|
||||
ConnStatus: StatusIdle,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
@@ -625,10 +643,10 @@ func (conn *Conn) setStatusToDisconnected() {
|
||||
if err != nil {
|
||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||
// todo rethink status updates
|
||||
conn.log.Debugf("error while updating peer's state, err: %v", err)
|
||||
conn.Log.Debugf("error while updating peer's state, err: %v", err)
|
||||
}
|
||||
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
|
||||
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
|
||||
conn.Log.Debugf("failed to reset wireguard stats for peer: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -656,32 +674,24 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (conn *Conn) isRelayed() bool {
|
||||
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay, conntype.ICETurn:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.currentConnPriority == connPriorityICEP2P {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) evalStatus() ConnStatus {
|
||||
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
|
||||
if conn.statusRelay.Get() == worker.StatusConnected || conn.statusICE.Get() == worker.StatusConnected {
|
||||
return StatusConnected
|
||||
}
|
||||
|
||||
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
return StatusDisconnected
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
defer func() {
|
||||
if !connected {
|
||||
@@ -689,12 +699,12 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
if conn.statusICE.Get() == StatusDisconnected {
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() != StatusConnected {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -716,7 +726,7 @@ func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDRelay != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDRelay); err != nil {
|
||||
conn.log.Errorf("After remove peer hook failed: %v", err)
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDRelay = ""
|
||||
@@ -725,7 +735,7 @@ func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDICE != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDICE); err != nil {
|
||||
conn.log.Errorf("After remove peer hook failed: %v", err)
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDICE = ""
|
||||
@@ -733,7 +743,7 @@ func (conn *Conn) freeUpConnID() {
|
||||
}
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
|
||||
Port: conn.config.WgConfig.WgListenPort,
|
||||
@@ -741,18 +751,18 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
|
||||
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||
}
|
||||
|
||||
func (conn *Conn) isICEActive() bool {
|
||||
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
|
||||
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
@@ -760,10 +770,10 @@ func (conn *Conn) removeWgPeer() error {
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||
conn.Log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||
}
|
||||
}
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -773,16 +783,16 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
|
||||
func (conn *Conn) logTraceConnState() {
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||
conn.Log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||
} else {
|
||||
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
|
||||
conn.Log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = proxy
|
||||
@@ -793,6 +803,10 @@ func (conn *Conn) AllowedIP() netip.Addr {
|
||||
return conn.config.WgConfig.AllowedIps[0].Addr()
|
||||
}
|
||||
|
||||
func (conn *Conn) AgentVersionString() string {
|
||||
return conn.config.AgentVersion
|
||||
}
|
||||
|
||||
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
if conn.config.RosenpassConfig.PubKey == nil {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
@@ -804,7 +818,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
|
||||
@@ -1,58 +1,29 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// StatusConnected indicate the peer is in connected state
|
||||
StatusConnected ConnStatus = iota
|
||||
// StatusIdle indicate the peer is in disconnected state
|
||||
StatusIdle ConnStatus = iota
|
||||
// StatusConnecting indicate the peer is in connecting state
|
||||
StatusConnecting
|
||||
// StatusDisconnected indicate the peer is in disconnected state
|
||||
StatusDisconnected
|
||||
// StatusConnected indicate the peer is in connected state
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
|
||||
type AtomicConnStatus struct {
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
|
||||
func NewAtomicConnStatus() *AtomicConnStatus {
|
||||
acs := &AtomicConnStatus{}
|
||||
acs.Set(StatusDisconnected)
|
||||
return acs
|
||||
}
|
||||
|
||||
// Get returns the current connection status
|
||||
func (acs *AtomicConnStatus) Get() ConnStatus {
|
||||
return ConnStatus(acs.status.Load())
|
||||
}
|
||||
|
||||
// Set updates the connection status
|
||||
func (acs *AtomicConnStatus) Set(status ConnStatus) {
|
||||
acs.status.Store(int32(status))
|
||||
}
|
||||
|
||||
// String returns the string representation of the current status
|
||||
func (acs *AtomicConnStatus) String() string {
|
||||
return acs.Get().String()
|
||||
}
|
||||
|
||||
func (s ConnStatus) String() string {
|
||||
switch s {
|
||||
case StatusConnecting:
|
||||
return "Connecting"
|
||||
case StatusConnected:
|
||||
return "Connected"
|
||||
case StatusDisconnected:
|
||||
return "Disconnected"
|
||||
case StatusIdle:
|
||||
return "Idle"
|
||||
default:
|
||||
log.Errorf("unknown status: %d", s)
|
||||
return "INVALID_PEER_CONNECTION_STATUS"
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestConnStatus_String(t *testing.T) {
|
||||
want string
|
||||
}{
|
||||
{"StatusConnected", StatusConnected, "Connected"},
|
||||
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
||||
{"StatusIdle", StatusIdle, "Idle"},
|
||||
{"StatusConnecting", StatusConnecting, "Connecting"},
|
||||
}
|
||||
|
||||
@@ -24,5 +24,4 @@ func TestConnStatus_String(t *testing.T) {
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
|
||||
var connConf = ConnConfig{
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
@@ -48,7 +50,13 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
|
||||
sd := ServiceDependencies{
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -60,7 +68,13 @@ func TestConn_GetKey(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -94,7 +108,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -125,43 +145,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables := []struct {
|
||||
name string
|
||||
statusIce ConnStatus
|
||||
statusRelay ConnStatus
|
||||
want ConnStatus
|
||||
}{
|
||||
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
|
||||
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
|
||||
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
|
||||
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
|
||||
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
|
||||
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
|
||||
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
t.Run(table.name, func(t *testing.T) {
|
||||
si := NewAtomicConnStatus()
|
||||
si.Set(table.statusIce)
|
||||
conn.statusICE = si
|
||||
|
||||
sr := NewAtomicConnStatus()
|
||||
sr.Set(table.statusRelay)
|
||||
conn.statusRelay = sr
|
||||
|
||||
got := conn.Status()
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_presharedKey(t *testing.T) {
|
||||
conn1 := Conn{
|
||||
|
||||
29
client/internal/peer/conntype/priority.go
Normal file
29
client/internal/peer/conntype/priority.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package conntype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
None ConnPriority = 0
|
||||
Relay ConnPriority = 1
|
||||
ICETurn ConnPriority = 2
|
||||
ICEP2P ConnPriority = 3
|
||||
)
|
||||
|
||||
type ConnPriority int
|
||||
|
||||
func (cp ConnPriority) String() string {
|
||||
switch cp {
|
||||
case None:
|
||||
return "None"
|
||||
case Relay:
|
||||
return "PriorityRelay"
|
||||
case ICETurn:
|
||||
return "PriorityICETurn"
|
||||
case ICEP2P:
|
||||
return "PriorityICEP2P"
|
||||
default:
|
||||
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||
}
|
||||
}
|
||||
52
client/internal/peer/dispatcher/dispatcher.go
Normal file
52
client/internal/peer/dispatcher/dispatcher.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package dispatcher
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type ConnectionListener struct {
|
||||
OnConnected func(peerID id.ConnID)
|
||||
OnDisconnected func(peerID id.ConnID)
|
||||
}
|
||||
|
||||
type ConnectionDispatcher struct {
|
||||
listeners map[*ConnectionListener]struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectionDispatcher() *ConnectionDispatcher {
|
||||
return &ConnectionDispatcher{
|
||||
listeners: make(map[*ConnectionListener]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) AddListener(listener *ConnectionListener) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.listeners[listener] = struct{}{}
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) RemoveListener(listener *ConnectionListener) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
delete(e.listeners, listener)
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) NotifyConnected(peerConnID id.ConnID) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
for listener := range e.listeners {
|
||||
listener.OnConnected(peerConnID)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) NotifyDisconnected(peerConnID id.ConnID) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
for listener := range e.listeners {
|
||||
listener.OnDisconnected(peerConnID)
|
||||
}
|
||||
}
|
||||
@@ -8,10 +8,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
reconnectMaxElapsedTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
@@ -25,7 +21,6 @@ type isConnectedFunc func() bool
|
||||
type Guard struct {
|
||||
Reconnect chan struct{}
|
||||
log *log.Entry
|
||||
isController bool
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
@@ -33,11 +28,10 @@ type Guard struct {
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
Reconnect: make(chan struct{}, 1),
|
||||
log: log,
|
||||
isController: isController,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
timeout: timeout,
|
||||
srWatcher: srWatcher,
|
||||
@@ -46,12 +40,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) Start(ctx context.Context) {
|
||||
if g.isController {
|
||||
g.reconnectLoopWithRetry(ctx)
|
||||
} else {
|
||||
g.listenForDisconnectEvents(ctx)
|
||||
}
|
||||
func (g *Guard) Start(ctx context.Context, eventCallback func()) {
|
||||
g.reconnectLoopWithRetry(ctx, eventCallback)
|
||||
}
|
||||
|
||||
func (g *Guard) SetRelayedConnDisconnected() {
|
||||
@@ -68,9 +58,9 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check (max 30 min) the connection status.
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
waitForInitialConnectionTry(ctx)
|
||||
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
@@ -93,7 +83,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
g.triggerOfferSending()
|
||||
callback()
|
||||
}
|
||||
|
||||
case <-g.relayedConnDisconnected:
|
||||
@@ -121,39 +111,12 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
|
||||
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
|
||||
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
|
||||
// mean that to switch to it. We always force to use the higher priority connection.
|
||||
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
|
||||
g.log.Infof("start listen for reconnect events...")
|
||||
for {
|
||||
select {
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, triggering reconnect")
|
||||
g.triggerOfferSending()
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE state changed, try to send new offer")
|
||||
g.triggerOfferSending()
|
||||
case <-srReconnectedChan:
|
||||
g.triggerOfferSending()
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
Multiplier: 2,
|
||||
MaxInterval: g.timeout,
|
||||
MaxElapsedTime: reconnectMaxElapsedTime,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
@@ -164,13 +127,6 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
return ticker
|
||||
}
|
||||
|
||||
func (g *Guard) triggerOfferSending() {
|
||||
select {
|
||||
case g.Reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Give chance to the peer to establish the initial connection.
|
||||
// With it, we can decrease to send necessary offer
|
||||
func waitForInitialConnectionTry(ctx context.Context) {
|
||||
|
||||
@@ -43,7 +43,6 @@ type OfferAnswer struct {
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
@@ -57,9 +56,8 @@ type Handshaker struct {
|
||||
remoteAnswerCh chan OfferAnswer
|
||||
}
|
||||
|
||||
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||
return &Handshaker{
|
||||
ctx: ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -74,10 +72,10 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen() {
|
||||
func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
h.log.Info("wait for remote offer confirmation")
|
||||
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
|
||||
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx)
|
||||
if err != nil {
|
||||
var connectionClosedError *ConnectionClosedError
|
||||
if errors.As(err, &connectionClosedError) {
|
||||
@@ -127,7 +125,7 @@ func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
||||
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
// received confirmation from the remote peer -> ready to proceed
|
||||
@@ -137,7 +135,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
||||
return &remoteOfferAnswer, nil
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
return &remoteOfferAnswer, nil
|
||||
case <-h.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
// closed externally
|
||||
return nil, NewConnectionClosedError(h.config.Key)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user