Compare commits

...

38 Commits

Author SHA1 Message Date
Zoltan Papp
aa07b3b87b Fix deadlock (#3904) 2025-05-30 23:38:02 +02:00
Bethuel Mmbaga
2bef214cc0 [management] Fix user groups propagation (#3902) 2025-05-30 18:12:30 +03:00
hakansa
cfb2d82352 [client] Refactor exclude list handling to use a map for permanent connections (#3901)
[client] Refactor exclude list handling to use a map for permanent connections (#3901)
2025-05-30 16:54:49 +03:00
Bethuel Mmbaga
684501fd35 [management] Prevent deletion of peers linked to network routers (#3881)
- Prevent deletion of peers linked to network routers
- Add API endpoint to list all network routers
2025-05-29 18:50:00 +03:00
Zoltan Papp
0492c1724a [client, android] Fix/notifier threading (#3807)
- Fix potential deadlocks
- When adding a listener, immediately notify with the last known IP and fqdn.
2025-05-27 17:12:04 +02:00
Zoltan Papp
6f436e57b5 [server-test] Install libs for i386 tests (#3887)
Install libs for i386 tests
2025-05-27 16:42:06 +02:00
Bethuel Mmbaga
a0d28f9851 [management] Reset test containers after cleanup (#3885) 2025-05-27 14:42:00 +03:00
Zoltan Papp
cdd27a9fe5 [client, android] Fix/android enable server route (#3806)
Enable the server route; otherwise, the manager throws an error and the engine will restart.
2025-05-27 13:32:54 +02:00
Bethuel Mmbaga
5523040acd [management] Add correlated network traffic event schema (#3680) 2025-05-27 13:47:53 +03:00
M. Essam
670446d42e [management/client/rest] Fix panic on unknown errors (#3865) 2025-05-25 16:57:34 +02:00
Pedro Maia Costa
5bed6777d5 [management] force account id on save groups update (#3850) 2025-05-23 14:42:42 +01:00
Pascal Fischer
a0482ebc7b [client] avoid overwriting state manager on iOS (#3870) 2025-05-23 14:04:12 +02:00
Bethuel Mmbaga
2a89d6e47a [management] Extend nameserver match domain validation (#3864)
* Enhance match domain validation logic and add test cases

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove the leading dot and root dot support ns regex

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove support for wildcard ns match domain

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 23:16:19 +02:00
Bethuel Mmbaga
24f932b2ce [management] Update traffic events pagination filters (#3857) 2025-05-22 16:28:14 +03:00
Pedro Maia Costa
c03435061c [management] lazy connection account setting (#3855) 2025-05-22 14:09:00 +01:00
Misha Bragin
8e948739f1 Fix CLA link in the PR template (#3860) 2025-05-22 10:38:58 +02:00
Maycon Santos
9b53cad752 [misc] add CLA note (#3859) 2025-05-21 22:40:36 +02:00
Zoltan Papp
802a18167c [client] Do not reconnect to mgm server in case of handler error (#3856)
* Do not reconnect to mgm server in case of handler error
Set to nil the flow grpc client to nil

* Better error handling
2025-05-21 20:18:21 +02:00
hakansa
e9108ffe6c [client] Add latest gzipped rotated log file to the debug bundle (#3848)
[client] Add latest gzipped rotated log file to the debug bundle
2025-05-21 17:50:54 +03:00
Viktor Liu
e806d9de38 [client] Fix legacy routes when connecting to management servers older than v0.30.0 (#3854) 2025-05-21 13:48:55 +02:00
Zoltan Papp
daa8380df9 [client] Feature/lazy connection (#3379)
With the lazy connection feature, the peer will connect to target peers on-demand. The trigger can be any IP traffic.

This feature can be enabled with the NB_ENABLE_EXPERIMENTAL_LAZY_CONN environment variable.

When the engine receives a network map, it binds a free UDP port for every remote peer, and the system configures WireGuard endpoints for these ports. When traffic appears on a UDP socket, the system removes this listener and starts the peer connection procedure immediately.

Key changes
Fix slow netbird status -d command
Move from engine.go file to conn_mgr.go the peer connection related code
Refactor the iface interface usage and moved interface file next to the engine code
Add new command line flag and UI option to enable feature
The peer.Conn struct is reusable after it has been closed.
Change connection states
Connection states
Idle: The peer is not attempting to establish a connection. This typically means it's in a lazy state or the remote peer is expired.

Connecting: The peer is actively trying to establish a connection. This occurs when the peer has entered an active state and is continuously attempting to reach the remote peer.

Connected: A successful peer-to-peer connection has been established and communication is active.
2025-05-21 11:12:28 +02:00
Bethuel Mmbaga
4785f23fc4 [management] Migrate events sqlite store to gorm (#3837) 2025-05-20 17:00:37 +03:00
Viktor Liu
1d4cfb83e7 [client] Fix UI new version notifier (#3845) 2025-05-20 10:39:17 +02:00
Pascal Fischer
207fa059d2 [management] make locking strength clause optional (#3844) 2025-05-19 16:42:47 +02:00
Viktor Liu
cbcdad7814 [misc] Update issue template (#3842) 2025-05-19 15:41:24 +02:00
Pascal Fischer
701c13807a [management] add flag to disable auto-migration (#3840) 2025-05-19 13:36:24 +02:00
Viktor Liu
99f8dc7748 [client] Offer to remove netbird data in windows uninstall (#3766) 2025-05-16 17:39:30 +02:00
Pascal Fischer
f1de8e6eb0 [management] Make startup period configurable (#3767) 2025-05-16 13:16:51 +02:00
Viktor Liu
b2a10780af [client] Disable dnssec for systemd explicitly (#3831) 2025-05-16 09:43:13 +02:00
Pascal Fischer
43ae79d848 [management] extend rest client lib (#3830) 2025-05-15 18:20:29 +02:00
Pascal Fischer
e520b64c6d [signal] remove stream receive server side (#3820) 2025-05-14 19:28:51 +02:00
hakansa
92c91bbdd8 [client] Add FreeBSD desktop client support to OAuth flow (#3822)
[client] Add FreeBSD desktop client support to OAuth flow
2025-05-14 19:52:02 +03:00
Vlad
adf494e1ac [management] fix a bug with missed extra dns labels for a new peer (#3798) 2025-05-14 17:50:21 +02:00
Vlad
2158461121 [management,client] PKCE add flag parameter prompt=login or max_age (#3824) 2025-05-14 17:48:51 +02:00
Bethuel Mmbaga
0cd4b601c3 [management] Add connection type filter to Network Traffic API (#3815) 2025-05-14 11:15:50 +03:00
Zoltan Papp
ee1cec47b3 [client, android] Do not propagate empty routes (#3805)
If we get domain routes the Network prefix variable in route structure will be invalid (engine.go:1057). When we handower to Android the routes, we must to filter out the domain routes. If we do not do it the Android code will get "invalid prefix" string as a route.
2025-05-13 15:21:06 +02:00
Pascal Fischer
efb0edfc4c [signal] adjust signal log levels 2 (#3817) 2025-05-12 23:52:29 +02:00
Pascal Fischer
20f59ddecb [signal] adjust log levels (#3813) 2025-05-12 19:48:47 +02:00
134 changed files with 5361 additions and 2611 deletions

View File

@@ -37,16 +37,21 @@ If yes, which one?
**Debug output**
To help us resolve the problem, please attach the following debug output
To help us resolve the problem, please attach the following anonymized status output
netbird status -dA
As well as the file created by
Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots**
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -13,3 +13,5 @@
- [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -223,6 +223,10 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -269,6 +273,10 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -179,6 +179,7 @@ jobs:
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
- name: Install modules
run: go mod tidy

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"runtime"
"strings"
"time"
@@ -98,11 +99,11 @@ var loginCmd = &cobra.Command{
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil {
return nil, err
}
@@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
}
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@@ -26,22 +26,23 @@ import (
)
const (
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
)
var (
@@ -80,6 +81,7 @@ var (
blockLANAccess bool
debugUploadBundle bool
debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -184,6 +186,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))

View File

@@ -44,7 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected":
case "", "idle", "connecting", "connected":
if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
}
if len(ipsFilter) > 0 {

View File

@@ -194,6 +194,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
@@ -262,17 +266,17 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -332,6 +336,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.BlockLanAccess = &blockLANAccess
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
}
var loginErr error
var loginResp *proto.LoginResponse

View File

@@ -201,14 +201,30 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() {
}
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey)
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
return nil, fmt.Errorf("wgctl: %w", err)
}
return WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}, nil
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}
}
return stats, nil
}

View File

@@ -1,6 +1,7 @@
package configurer
import (
"encoding/base64"
"encoding/hex"
"fmt"
"net"
@@ -17,6 +18,13 @@ import (
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
@@ -217,91 +225,75 @@ func (t *WGUSPConfigurer) Close() {
}
}
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err)
return nil, fmt.Errorf("ipc get: %w", err)
}
stats, err := findPeerInfo(ipc, peerKey, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
return parseTransfers(ipc)
}
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return nil, fmt.Errorf("parse key: %w", err)
}
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
func parseTransfers(ipc string) (map[string]WGStats, error) {
stats := make(map[string]WGStats)
var (
currentKey string
currentStats WGStats
hasPeer bool
)
lines := strings.Split(ipc, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") && foundPeer {
break
}
// Identify the peer with the specific public key
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
for _, key := range searchConfigKeys {
if foundPeer && strings.HasPrefix(line, key+"=") {
v := strings.SplitN(line, "=", 2)
configFound[v[0]] = v[1]
if strings.HasPrefix(line, "public_key=") {
peerID := strings.TrimPrefix(line, "public_key=")
h, err := hex.DecodeString(peerID)
if err != nil {
return nil, fmt.Errorf("decode peerID: %w", err)
}
currentKey = base64.StdEncoding.EncodeToString(h)
currentStats = WGStats{} // Reset stats for the new peer
hasPeer = true
stats[currentKey] = currentStats
continue
}
if !hasPeer {
continue
}
key := strings.SplitN(line, "=", 2)
if len(key) != 2 {
continue
}
switch key[0] {
case ipcKeyLastHandshakeTimeSec:
hs, err := toLastHandshake(key[1])
if err != nil {
return nil, err
}
currentStats.LastHandshake = hs
stats[currentKey] = currentStats
case ipcKeyRxBytes:
rxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
currentStats.RxBytes = rxBytes
stats[currentKey] = currentStats
case ipcKeyTxBytes:
TxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
currentStats.TxBytes = TxBytes
stats[currentKey] = currentStats
}
}
// todo: use multierr
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
return stats, nil
}
func toWgUserspaceString(wgCfg wgtypes.Config) string {
@@ -355,6 +347,18 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String()
}
func toLastHandshake(stringVar string) (time.Time, error) {
sec, err := strconv.ParseInt(stringVar, 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
return time.Unix(sec, 0), nil
}
func toBytes(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
func getFwmark() int {
if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark

View File

@@ -2,10 +2,8 @@ package configurer
import (
"encoding/hex"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -34,58 +32,35 @@ errno=0
`
func Test_findPeerInfo(t *testing.T) {
func Test_parseTransfers(t *testing.T) {
tests := []struct {
name string
peerKey string
searchKeys []string
want map[string]string
wantErr bool
name string
peerKey string
want WGStats
}{
{
name: "single",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
name: "single",
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
want: WGStats{
TxBytes: 0,
RxBytes: 0,
},
wantErr: false,
},
{
name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
"rx_bytes": "2224",
name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
want: WGStats{
TxBytes: 38333,
RxBytes: 2224,
},
wantErr: false,
},
{
name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "1212111",
"rx_bytes": "1929999999",
name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
want: WGStats{
TxBytes: 1212111,
RxBytes: 1929999999,
},
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
},
}
for _, tt := range tests {
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
key, err := wgtypes.NewKey(res)
require.NoError(t, err)
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
stats, err := parseTransfers(ipcFixture)
if err != nil {
require.NoError(t, err)
return
}
stat, ok := stats[key.String()]
if !ok {
require.True(t, ok)
return
}
require.Equal(t, tt.want, stat)
})
}
}

View File

@@ -16,5 +16,5 @@ type WGConfigurer interface {
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats(peerKey string) (configurer.WGStats, error)
GetStats() (map[string]configurer.WGStats, error)
}

View File

@@ -212,9 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey)
// GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats()
}
func (w *WGIface) waitUntilRemoved() error {

View File

@@ -24,6 +24,8 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True
######################################################################
@@ -49,6 +51,10 @@ ShowInstDetails Show
######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@@ -58,9 +64,6 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!define MUI_ABORTWARNING
!define MUI_UNABORTWARNING
@@ -70,13 +73,16 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH
!insertmacro MUI_UNPAGE_WELCOME
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
Var AutostartCheckbox
Var AutostartEnabled
; Variables for uninstall data deletion option
Var DeleteDataCheckbox
Var DeleteDataEnabled
######################################################################
; Function to create the autostart options page
@@ -104,8 +114,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
${NSD_Check} $AutostartCheckbox
StrCpy $AutostartEnabled "1"
nsDialogs::Show
FunctionEnd
@@ -115,6 +125,30 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
; Function to create the uninstall data deletion page
Function un.DeleteDataPage
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
Pop $DeleteDataCheckbox
${NSD_Uncheck} $DeleteDataCheckbox
StrCpy $DeleteDataEnabled "0"
nsDialogs::Show
FunctionEnd
; Function to handle leaving the data deletion page
Function un.DeleteDataPageLeave
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
FunctionEnd
Function GetAppFromCommand
Exch $1
Push $2
@@ -176,10 +210,10 @@ ${EndIf}
FunctionEnd
######################################################################
Section -MainProgram
${INSTALL_TYPE}
# SetOverwrite ifnewer
SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\"
${INSTALL_TYPE}
# SetOverwrite ifnewer
SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\"
SectionEnd
######################################################################
@@ -225,31 +259,58 @@ SectionEnd
Section Uninstall
${INSTALL_TYPE}
DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."
${If} $DeleteDataEnabled == "1"
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
ClearErrors
RMDir /r "${NETBIRD_DATA_DIR}"
IfErrors 0 +2 ; If no errors, jump over the message
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
DetailPrint "Netbird data directory removal complete."
${Else}
DetailPrint "User did not opt to delete Netbird data."
${EndIf}
# wait the service uninstall take unblock the executable
DetailPrint "Waiting for service handle to be released..."
Sleep 3000
DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll"
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR"
DetailPrint "Removing shortcuts..."
SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR"
DetailPrint "Uninstallation finished."
SectionEnd

View File

@@ -76,12 +76,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap)
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)

View File

@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}

View File

@@ -101,7 +101,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
}
if !p.providerConfig.DisablePromptLogin {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
if p.providerConfig.LoginFlag.IsPromptLogin() {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -7,15 +7,36 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
)
func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct {
name string
prompt bool
name string
loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
}{
{"PromptLogin", true},
{"NoPromptLogin", false},
{
name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
}
for _, tc := range tt {
@@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
DisablePromptLogin: !tc.prompt,
LoginFlag: tc.loginFlag,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
@@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
if err != nil {
t.Fatalf("Failed to request auth info: %v", err)
}
pattern := "prompt=login"
if tc.prompt {
require.Contains(t, authInfo.VerificationURIComplete, pattern)
if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
}
})
}

View File

@@ -74,6 +74,8 @@ type ConfigInput struct {
DisableNotifications *bool
DNSLabels domain.List
LazyConnectionEnabled *bool
}
// Config Configuration type
@@ -138,6 +140,8 @@ type Config struct {
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -524,6 +528,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
updated = true
}
return updated, nil
}

303
client/internal/conn_mgr.go Normal file
View File

@@ -0,0 +1,303 @@
package internal
import (
"context"
"os"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
// - Managing lazy connections via the lazyConnManager
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
return nil
}
if enabled {
// if the lazy connection manager is already started, do not start it again
if e.lazyConnMgr != nil {
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager(ctx)
} else {
if e.lazyConnMgr == nil {
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
e.closeManager(ctx)
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
if e.lazyConnMgr == nil {
return
}
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for peerID := range peerIDs {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
continue
}
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(e.ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err)
}
}
}
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
return true
}
if !e.isStartedWithLazyMgr() {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if !lazyconn.IsSupported(conn.AgentVersionString()) {
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerKey,
AllowedIPs: conn.WgConfig().AllowedIps,
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if excluded {
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
conn.Log.Infof("peer added to lazy conn manager")
return
}
func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn, ok := e.peerStore.Remove(peerKey)
if !ok {
return
}
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
}
e.lazyConnMgr.RemovePeer(peerKey)
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(e.ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
func (e *ConnMgr) Close() {
if !e.isStartedWithLazyMgr() {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
}
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
ctx, cancel := context.WithCancel(parentCtx)
e.ctx = ctx
e.ctxCancel = cancel
e.wg.Add(1)
go func() {
defer e.wg.Done()
e.lazyConnMgr.Start(ctx)
}()
}
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {
if e.lazyConnMgr == nil {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
for _, peerID := range e.peerStore.PeersPubKey() {
e.peerStore.PeerConnOpen(ctx, peerID)
}
}
func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.ctxCancel != nil
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {
return nil
}
parsedMinutes, err := strconv.Atoi(envValue)
if err != nil || parsedMinutes <= 0 {
return nil
}
d := time.Duration(parsedMinutes) * time.Minute
return &d
}

View File

@@ -440,7 +440,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
BlockLANAccess: config.BlockLANAccess,
LazyConnectionEnabled: config.LazyConnectionEnabled,
}
if config.PreSharedKey != "" {
@@ -481,7 +482,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
return signalClient, nil
}
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()

View File

@@ -4,6 +4,7 @@ import (
"archive/zip"
"bufio"
"bytes"
"compress/gzip"
"encoding/json"
"errors"
"fmt"
@@ -376,6 +377,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
}
func (g *BundleGenerator) addProf() (err error) {
@@ -533,6 +535,33 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err)
}
// add latest rotated log file
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
if runtime.GOOS == "darwin" {
@@ -563,16 +592,13 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
}
}()
var logReader io.Reader
var logReader io.Reader = logFile
if g.anonymize {
var writer *io.PipeWriter
logReader, writer = io.Pipe()
go anonymizeLog(logFile, writer, g.anonymizer)
} else {
logReader = logFile
}
if err := g.addFileToZip(logReader, targetName); err != nil {
return fmt.Errorf("add %s to zip: %w", targetName, err)
}
@@ -580,6 +606,44 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
return nil
}
// addSingleLogFileGz adds a single gzipped log file to the archive
func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
f, err := os.Open(logPath)
if err != nil {
return fmt.Errorf("open gz log file %s: %w", targetName, err)
}
defer f.Close()
gzr, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("create gzip reader: %w", err)
}
defer gzr.Close()
var logReader io.Reader = gzr
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(gzr, pw, g.anonymizer)
}
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
if _, err := io.Copy(gw, logReader); err != nil {
return fmt.Errorf("re-gzip: %w", err)
}
if err := gw.Close(); err != nil {
return fmt.Errorf("close gzip writer: %w", err)
}
if err := g.addFileToZip(&buf, targetName); err != nil {
return fmt.Errorf("add anonymized gz: %w", err)
}
return nil
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,

View File

@@ -30,9 +30,12 @@ const (
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
dnsSecDisabled = "no"
)
type systemdDbusConfigurator struct {
@@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
Family: unix.AF_INET,
Address: ipAs4[:],
}
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
if err != nil {
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
}
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSSEC to 'no': %v", err)
}
var (

View File

@@ -5,7 +5,6 @@ package dns
import (
"net"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -18,5 +17,4 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -1,7 +1,6 @@
package dns
import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -13,6 +12,5 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error)
}

View File

@@ -38,6 +38,7 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
@@ -122,6 +123,8 @@ type EngineConfig struct {
DisableFirewall bool
BlockLANAccess bool
LazyConnectionEnabled bool
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -134,6 +137,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer
peerStore *peerstore.Store
connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
@@ -170,7 +175,8 @@ type Engine struct {
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
statusRecorder *peer.Status
statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
firewall firewallManager.Manager
routeManager routemanager.Manager
@@ -235,6 +241,8 @@ func NewEngine(
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
}
path := statemanager.GetDefaultStatePath()
if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
@@ -244,11 +252,9 @@ func NewEngine(
}
}
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
}
if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path)
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
return engine
}
@@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
}
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
@@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
err := e.removeAllPeers()
if err != nil {
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
}
@@ -405,8 +414,7 @@ func (e *Engine) Start() error {
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
err = e.wgInterfaceCreate()
if err != nil {
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err)
@@ -442,6 +450,11 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
@@ -450,7 +463,6 @@ func (e *Engine) Start() error {
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
return nil
}
@@ -550,6 +562,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey()
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
if !ok {
continue
}
if currentPeer.AgentVersionString() != p.AgentVersion {
modified = append(modified, p)
continue
}
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
if !ok {
continue
@@ -559,8 +581,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
continue
}
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
if err != nil {
if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
}
}
@@ -621,16 +642,11 @@ func (e *Engine) removePeer(peerKey string) error {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
defer func() {
err := e.statusRecorder.RemovePeer(peerKey)
if err != nil {
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
}
}()
e.connMgr.RemovePeerConn(peerKey)
conn, exists := e.peerStore.Remove(peerKey)
if exists {
conn.Close()
err := e.statusRecorder.RemovePeer(peerKey)
if err != nil {
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
}
return nil
}
@@ -952,12 +968,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err)
}
if e.firewall != nil {
if localipfw, ok := e.firewall.(localIpUpdater); ok {
if err := localipfw.UpdateLocalIPs(); err != nil {
log.Errorf("failed to update local IPs: %v", err)
}
}
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
// This needs to be toggled before applying routes.
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
@@ -976,7 +1004,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
@@ -1022,6 +1051,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(excludedLazyPeers)
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
@@ -1155,7 +1188,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
PubKey: offlinePeer.GetWgPubKey(),
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatus: peer.StatusIdle,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
@@ -1191,12 +1224,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerIPs = append(peerIPs, allowedNetIP)
}
conn, err := e.createPeerConn(peerKey, peerIPs)
conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
}
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
@@ -1205,17 +1243,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
conn.Open()
return nil
}
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
wgConfig := peer.WgConfig{
@@ -1229,11 +1260,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
// randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{
Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
Timeout: timeout,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
AgentVersion: agentVersion,
Timeout: timeout,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
RosenpassConfig: peer.RosenpassConfig{
PubKey: e.getRosenpassPubKey(),
Addr: e.getRosenpassAddr(),
@@ -1249,7 +1281,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
},
}
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
return nil, err
}
@@ -1270,7 +1311,7 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn, ok := e.peerStore.PeerConn(msg.Key)
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
@@ -1578,13 +1619,39 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool {
e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy()
log.Debugf("signal health check: healthy=%t", signalHealthy)
managementHealthy := e.mgmClient.IsHealthy()
log.Debugf("management health check: healthy=%t", managementHealthy)
results := append(e.probeSTUNs(), e.probeTURNs()...)
stuns := slices.Clone(e.STUNs)
turns := slices.Clone(e.TURNs)
if e.wgInterface != nil {
stats, err := e.wgInterface.GetStats()
if err != nil {
log.Warnf("failed to get wireguard stats: %v", err)
e.syncMsgMux.Unlock()
return false
}
for _, key := range e.peerStore.PeersPubKey() {
// wgStats could be zero value, in which case we just reset the stats
wgStats, ok := stats[key]
if !ok {
continue
}
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
}
e.syncMsgMux.Unlock()
results := e.probeICE(stuns, turns)
e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true
@@ -1596,37 +1663,16 @@ func (e *Engine) RunHealthProbes() bool {
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
for _, key := range e.peerStore.PeersPubKey() {
wgStats, err := e.wgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
continue
}
// wgStats could be zero value, in which case we just reset the stats
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)
return allHealthy
}
func (e *Engine) probeSTUNs() []relay.ProbeResult {
e.syncMsgMux.Lock()
stuns := slices.Clone(e.STUNs)
e.syncMsgMux.Unlock()
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
}
func (e *Engine) probeTURNs() []relay.ProbeResult {
e.syncMsgMux.Lock()
turns := slices.Clone(e.TURNs)
e.syncMsgMux.Unlock()
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)...,
)
}
// restartEngine restarts the engine by cancelling the client context
@@ -1813,21 +1859,21 @@ func (e *Engine) Address() (netip.Addr, error) {
return ip.Unmap(), nil
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules")
return nil
return nil, nil
}
if len(rules) == 0 {
if e.ingressGatewayMgr == nil {
return nil
return nil, nil
}
err := e.ingressGatewayMgr.Close()
e.ingressGatewayMgr = nil
e.statusRecorder.SetIngressGwMgr(nil)
return err
return nil, err
}
if e.ingressGatewayMgr == nil {
@@ -1878,7 +1924,35 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
log.Errorf("failed to update forwarding rules: %v", err)
}
return nberrors.FormatErrorOrNil(merr)
return forwardingRules, nberrors.FormatErrorOrNil(merr)
}
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
excludedPeers := make(map[string]bool)
for _, r := range routes {
if r.Peer == "" {
continue
}
if !excludedPeers[r.Peer] {
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers[r.Peer] = true
}
}
for _, r := range rules {
ip := r.TranslatedAddress
for _, p := range peers {
for _, allowedIP := range p.GetAllowedIps() {
if allowedIP != ip.String() {
continue
}
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
excludedPeers[p.GetWgPubKey()] = true
}
}
}
return excludedPeers
}
// isChecksEqual checks if two slices of checks are equal.

View File

@@ -28,8 +28,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -38,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -53,6 +52,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
@@ -93,7 +93,7 @@ type MockWGIface struct {
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetWGDeviceFunc func() *wgdevice.Device
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net
@@ -171,8 +171,8 @@ func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
return m.GetWGDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
func (m *MockWGIface) GetStats() (map[string]configurer.WGStats, error) {
return m.GetStatsFunc()
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
@@ -378,6 +378,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
},
}
},
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return nil
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
@@ -400,6 +403,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
type testCase struct {
name string
@@ -770,6 +775,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {
exitErr := engine.Stop()
@@ -966,6 +973,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {
exitErr := engine.Stop()
@@ -1476,7 +1485,7 @@ func getConnectedPeers(e *Engine) int {
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.Status() == peer.StatusConnected {
if conn.IsConnected() {
i++
}
}

View File

@@ -35,6 +35,6 @@ type wgIfaceBase interface {
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetStats(peerKey string) (configurer.WGStats, error)
GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net
}

View File

@@ -0,0 +1,9 @@
//go:build !linux || android
package activity
import "net"
var (
listenIP = net.IP{127, 0, 0, 1}
)

View File

@@ -0,0 +1,10 @@
//go:build !android
package activity
import "net"
var (
// use this ip to avoid eBPF proxy congestion
listenIP = net.IP{127, 0, 1, 1}
)

View File

@@ -0,0 +1,106 @@
package activity
import (
"fmt"
"net"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
wgIface lazyconn.WGIface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener
}
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
wgIface: wgIface,
peerCfg: cfg,
}
conn, err := d.newConn()
if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
}
d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
if err := d.createEndpoint(); err != nil {
return nil, err
}
d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
return d, nil
}
func (d *Listener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Debugf("exit from activity listener")
} else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
}
break
}
if n < 1 {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
break
}
if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
d.done.Unlock()
}
func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
if err := d.conn.Close(); err != nil {
d.peerCfg.Log.Errorf("failed to close UDP listener: %s", err)
}
d.done.Lock()
}
func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
}
func (d *Listener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{
Port: 0,
IP: listenIP,
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
log.Errorf("failed to create activity listener on %s: %s", addr, err)
return nil, err
}
return conn, nil
}

View File

@@ -0,0 +1,41 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -0,0 +1,95 @@
package activity
import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type Manager struct {
OnActivityChan chan peerid.ConnID
wgIface lazyconn.WGIface
peers map[peerid.ConnID]*Listener
done chan struct{}
mu sync.Mutex
}
func NewManager(wgIface lazyconn.WGIface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener),
done: make(chan struct{}),
}
return m
}
func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.peers[peerCfg.PeerConnID]; ok {
log.Warnf("activity listener already exists for: %s", peerCfg.PublicKey)
return nil
}
listener, err := NewListener(m.wgIface, peerCfg)
if err != nil {
return err
}
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
listener, ok := m.peers[peerConnID]
if !ok {
return
}
log.Debugf("removing activity listener")
delete(m.peers, peerConnID)
listener.Close()
}
func (m *Manager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
close(m.done)
for peerID, listener := range m.peers {
delete(m.peers, peerID)
listener.Close()
}
}
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
listener.ReadPackets()
m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok {
m.mu.Unlock()
return
}
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
}
}

View File

@@ -0,0 +1,162 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
PeerID string
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
type MocWGIface struct {
}
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
func TestManager_MonitorPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
}
case <-time.After(1 * time.Second):
}
}
func TestManager_RemovePeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
if err := trigger(addr); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case <-mgr.OnActivityChan:
t.Fatal("should not have active activity")
case <-time.After(1 * time.Second):
}
}
func TestManager_MultiPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
peer2 := &MocPeer{}
peerCfg2 := lazyconn.PeerConfig{
PublicKey: peer2.PeerID,
PeerConnID: peer2.ConnID(),
Log: log.WithField("peer", "examplePublicKey2"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := mgr.MonitorPeerActivity(peerCfg2); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
for i := 0; i < 2; i++ {
select {
case <-mgr.OnActivityChan:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
}
}
}
func trigger(addr string) error {
// Create a connection to the destination UDP address and port
conn, err := net.Dial("udp", addr)
if err != nil {
return err
}
defer conn.Close()
// Write the bytes to the UDP connection
_, err = conn.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,32 @@
/*
Package lazyconn provides mechanisms for managing lazy connections, which activate on demand to optimize resource usage and establish connections efficiently.
## Overview
The package includes a `Manager` component responsible for:
- Managing lazy connections activated on-demand
- Managing inactivity monitors for lazy connections (based on peer disconnection events)
- Maintaining a list of excluded peers that should always have permanent connections
- Handling remote peer connection initiatives based on peer signaling
## Thread-Safe Operations
The `Manager` ensures thread safety across multiple operations, categorized by caller:
- **Engine (single goroutine)**:
- `AddPeer`: Adds a peer to the connection manager.
- `RemovePeer`: Removes a peer from the connection manager.
- `ActivatePeer`: Activates a lazy connection for a peer. This come from Signal client
- `ExcludePeer`: Marks peers for a permanent connection. Like router peers and other peers that should always have a connection.
- **Connection Dispatcher (any peer routine)**:
- `onPeerConnected`: Suspend the inactivity monitor for an active peer connection.
- `onPeerDisconnected`: Starts the inactivity monitor for a disconnected peer.
- **Activity Manager**:
- `onPeerActivity`: Run peer.Open(context).
- **Inactivity Monitor**:
- `onPeerInactivityTimedOut`: Close peer connection and restart activity monitor.
*/
package lazyconn

View File

@@ -0,0 +1,26 @@
package lazyconn
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
)
func IsLazyConnEnabledByEnv() bool {
val := os.Getenv(EnvEnableLazyConn)
if val == "" {
return false
}
enabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
return false
}
return enabled
}

View File

@@ -0,0 +1,70 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}

View File

@@ -0,0 +1,156 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
trashHold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), trashHold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check to do not receive timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -0,0 +1,404 @@
package manager
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
const (
watcherActivity watcherType = iota
watcherInactivity
)
type watcherType int
type managedPeer struct {
peerCfg *lazyconn.PeerConfig
expectedWatcher watcherType
}
type Config struct {
InactivityThreshold *time.Duration
}
// Manager manages lazy connections
// It is responsible for:
// - Managing lazy connections activated on-demand
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
type Manager struct {
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
cancel context.CancelFunc
onInactive chan peerid.ConnID
}
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
onInactive: make(chan peerid.ConnID),
}
if config.InactivityThreshold != nil {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
m.inactivityThreshold = *config.InactivityThreshold
} else {
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
}
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
ctx, m.cancel = context.WithCancel(ctx)
for {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID)
}
}
}
// ExcludePeer marks peers for a permanent connection
// It removes peers from the managed list if they are added to the exclude list
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
// this case, we suppose that the connection status is connected or connecting.
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
added := make([]string, 0)
excludes := make(map[string]lazyconn.PeerConfig, len(peerConfigs))
for _, peerCfg := range peerConfigs {
log.Infof("update excluded lazy connection list with peer: %s", peerCfg.PublicKey)
excludes[peerCfg.PublicKey] = peerCfg
}
// if a peer is newly added to the exclude list, remove from the managed peers list
for pubKey, peerCfg := range excludes {
if _, wasExcluded := m.excludes[pubKey]; wasExcluded {
continue
}
added = append(added, pubKey)
peerCfg.Log.Infof("peer newly added to lazy connection exclude list")
m.removePeer(pubKey)
}
// if a peer has been removed from exclude list then it should be added to the managed peers
for pubKey, peerCfg := range m.excludes {
if _, stillExcluded := excludes[pubKey]; stillExcluded {
continue
}
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(ctx, peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue
}
}
m.excludes = excludes
return added
}
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
peerCfg.Log.Debugf("adding peer to lazy connection manager")
_, exists := m.excludes[peerCfg.PublicKey]
if exists {
return true, nil
}
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return false, nil
}
if err := m.activityManager.MonitorPeerActivity(peerCfg); err != nil {
return false, err
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherActivity,
}
return false, nil
}
// AddActivePeers adds a list of peers to the lazy connection manager
// suppose these peers was in connected or in connecting states
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
for _, cfg := range peerCfg {
if _, ok := m.managedPeers[cfg.PublicKey]; ok {
cfg.Log.Errorf("peer already managed")
continue
}
if err := m.addActivePeer(ctx, cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err
}
}
return nil
}
func (m *Manager) RemovePeer(peerID string) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.removePeer(peerID)
}
// ActivatePeer activates a peer connection when a signal message is received
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, ok := m.managedPeers[peerID]
if !ok {
return false
}
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
return false
}
// signal messages coming continuously after success activation, with this avoid the multiple activation
if mp.expectedWatcher == watcherInactivity {
return false
}
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
return false
}
mp.peerCfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
return true
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return nil
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherInactivity,
}
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
go im.Start(ctx, m.onInactive)
return nil
}
func (m *Manager) removePeer(peerID string) {
cfg, ok := m.managedPeers[peerID]
if !ok {
return
}
cfg.Log.Infof("removing lazy peer")
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID)
}
func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.cancel()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
log.Infof("lazy connection manager closed")
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherActivity {
mp.peerCfg.Log.Warnf("ignore activity event")
return
}
mp.peerCfg.Log.Infof("detected peer activity")
mp.expectedWatcher = watcherInactivity
mp.peerCfg.Log.Infof("starting inactivity monitor")
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
// just in case free up
m.inactivityMonitors[peerConnID].PauseTimer()
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

View File

@@ -0,0 +1,16 @@
package lazyconn
import (
"net/netip"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type PeerConfig struct {
PublicKey string
AllowedIPs []netip.Prefix
PeerConnID id.ConnID
Log *log.Entry
}

View File

@@ -0,0 +1,41 @@
package lazyconn
import (
"strings"
"github.com/hashicorp/go-version"
)
var (
minVersion = version.Must(version.NewVersion("0.45.0"))
)
func IsSupported(agentVersion string) bool {
if agentVersion == "development" {
return true
}
// filter out versions like this: a6c5960, a7d5c522, d47be154
if !strings.Contains(agentVersion, ".") {
return false
}
normalizedVersion := normalizeVersion(agentVersion)
inputVer, err := version.NewVersion(normalizedVersion)
if err != nil {
return false
}
return inputVer.GreaterThanOrEqual(minVersion)
}
func normalizeVersion(version string) string {
// Remove prefixes like 'v' or 'a'
if len(version) > 0 && (version[0] == 'v' || version[0] == 'a') {
version = version[1:]
}
// Remove any suffixes like '-dirty', '-dev', '-SNAPSHOT', etc.
parts := strings.Split(version, "-")
return parts[0]
}

View File

@@ -0,0 +1,31 @@
package lazyconn
import "testing"
func TestIsSupported(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"development", true},
{"0.45.0", true},
{"v0.45.0", true},
{"0.45.1", true},
{"0.45.1-SNAPSHOT-559e6731", true},
{"v0.45.1-dev", true},
{"a7d5c522", false},
{"0.9.6", false},
{"0.9.6-SNAPSHOT", false},
{"0.9.6-SNAPSHOT-2033650", false},
{"meta_wt_version", false},
{"v0.31.1-dev", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
if got := IsSupported(tt.version); got != tt.want {
t.Errorf("IsSupported() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,14 @@
package lazyconn
import (
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}

View File

@@ -123,8 +123,14 @@ func (m *Manager) disableFlow() error {
m.logger.Close()
if m.receiverClient != nil {
return m.receiverClient.Close()
if m.receiverClient == nil {
return nil
}
err := m.receiverClient.Close()
m.receiverClient = nil
if err != nil {
return fmt.Errorf("close: %w", err)
}
return nil

View File

@@ -17,8 +17,12 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
@@ -26,32 +30,20 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case connPriorityNone:
return "None"
case connPriorityRelay:
return "PriorityRelay"
case connPriorityICETurn:
return "PriorityICETurn"
case connPriorityICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}
const (
defaultWgKeepAlive = 25 * time.Second
connPriorityNone ConnPriority = 0
connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 2
connPriorityICEP2P ConnPriority = 3
)
type ServiceDependencies struct {
StatusRecorder *Status
Signaler *Signaler
IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher
}
type WgConfig struct {
WgListenPort int
RemoteKey string
@@ -76,6 +68,8 @@ type ConnConfig struct {
// LocalKey is a public key of a local peer
LocalKey string
AgentVersion string
Timeout time.Duration
WgConfig WgConfig
@@ -89,22 +83,23 @@ type ConnConfig struct {
}
type Conn struct {
log *log.Entry
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
handshaker *Handshaker
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus
currentConnPriority ConnPriority
statusRelay *worker.AtomicWorkerStatus
statusICE *worker.AtomicWorkerStatus
currentConnPriority conntype.ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection
workerICE *WorkerICE
@@ -120,9 +115,12 @@ type Conn struct {
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
peerConnDispatcher *dispatcher.ConnectionDispatcher
wg sync.WaitGroup
// debug purpose
dumpState *stateDump
@@ -130,91 +128,101 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
if len(config.WgConfig.AllowedIps) == 0 {
return nil, fmt.Errorf("allowed IPs is empty")
}
ctx, ctxCancel := context.WithCancel(engineCtx)
connLog := log.WithField("peer", config.Key)
var conn = &Conn{
log: connLog,
ctx: ctx,
ctxCancel: ctxCancel,
config: config,
statusRecorder: statusRecorder,
signaler: signaler,
relayManager: relayManager,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
dumpState: newStateDump(config.Key, connLog, statusRecorder),
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
peerConnDispatcher: services.PeerConnDispatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
}
ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen()
go conn.dumpState.Start(ctx)
return conn, nil
}
// Open opens connection to the remote peer
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open() {
conn.semaphore.Add(conn.ctx)
conn.log.Debugf("open connection to peer")
func (conn *Conn) Open(engineCtx context.Context) error {
conn.semaphore.Add(engineCtx)
conn.mu.Lock()
defer conn.mu.Unlock()
conn.opened = true
if conn.opened {
conn.semaphore.Done(engineCtx)
return nil
}
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.handshaker.Listen(conn.ctx)
}()
go conn.dumpState.Start(conn.ctx)
peerState := State{
PubKey: conn.config.Key,
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
ConnStatusUpdate: time.Now(),
ConnStatus: StatusDisconnected,
ConnStatus: StatusConnecting,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
conn.log.Warnf("error while updating the state err: %v", err)
if err := conn.statusRecorder.UpdatePeerState(peerState); err != nil {
conn.Log.Warnf("error while updating the state err: %v", err)
}
go conn.startHandshakeAndReconnect(conn.ctx)
}
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx)
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
defer conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx)
conn.dumpState.SendOffer()
if err := conn.handshaker.sendOffer(); err != nil {
conn.Log.Errorf("failed to send initial offer: %v", err)
}
conn.dumpState.SendOffer()
err := conn.handshaker.sendOffer()
if err != nil {
conn.log.Errorf("failed to send initial offer: %v", err)
}
go conn.guard.Start(ctx)
go conn.listenGuardEvent(ctx)
conn.wg.Add(1)
go func() {
conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.wg.Done()
}()
}()
conn.opened = true
return nil
}
// Close closes this peer Conn issuing a close event to the Conn closeCh
@@ -223,14 +231,14 @@ func (conn *Conn) Close() {
defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock()
conn.log.Infof("close peer connection")
conn.ctxCancel()
if !conn.opened {
conn.log.Debugf("ignore close connection to peer")
conn.Log.Debugf("ignore close connection to peer")
return
}
conn.Log.Infof("close peer connection")
conn.ctxCancel()
conn.workerRelay.DisableWgWatcher()
conn.workerRelay.CloseConn()
conn.workerICE.Close()
@@ -238,7 +246,7 @@ func (conn *Conn) Close() {
if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn()
if err != nil {
conn.log.Errorf("failed to close wg proxy for relay: %v", err)
conn.Log.Errorf("failed to close wg proxy for relay: %v", err)
}
conn.wgProxyRelay = nil
}
@@ -246,13 +254,13 @@ func (conn *Conn) Close() {
if conn.wgProxyICE != nil {
err := conn.wgProxyICE.CloseConn()
if err != nil {
conn.log.Errorf("failed to close wg proxy for ice: %v", err)
conn.Log.Errorf("failed to close wg proxy for ice: %v", err)
}
conn.wgProxyICE = nil
}
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.freeUpConnID()
@@ -262,14 +270,16 @@ func (conn *Conn) Close() {
}
conn.setStatusToDisconnected()
conn.log.Infof("peer connection has been closed")
conn.opened = false
conn.wg.Wait()
conn.Log.Infof("peer connection closed")
}
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
conn.dumpState.RemoteAnswer()
conn.log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer)
}
@@ -298,7 +308,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
conn.dumpState.RemoteOffer()
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer)
}
@@ -307,19 +317,24 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig
}
// Status returns current status of the Conn
func (conn *Conn) Status() ConnStatus {
// IsConnected unit tests only
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
return conn.evalStatus()
return conn.currentConnPriority != conntype.None
}
func (conn *Conn) GetKey() string {
return conn.config.Key
}
func (conn *Conn) ConnID() id.ConnID {
return id.ConnID(conn)
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -327,21 +342,21 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return
}
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
conn.log.Errorf("remote ICE connection is nil")
if remoteConnNil(conn.Log, iceConnInfo.RemoteConn) {
conn.Log.Errorf("remote ICE connection is nil")
return
}
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
// todo consider to remove this check
if conn.currentConnPriority > priority {
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
conn.statusICE.Set(StatusConnected)
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
return
}
conn.log.Infof("set ICE to active connection")
conn.Log.Infof("set ICE to active connection")
conn.dumpState.P2PConnected()
var (
@@ -353,7 +368,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
conn.dumpState.NewLocalProxy()
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return
}
ep = wgProxy.EndpointAddr()
@@ -369,7 +384,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
}
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher()
@@ -388,10 +403,16 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return
}
wgConfigWorkaround()
oldState := conn.currentConnPriority
conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected)
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
if oldState == conntype.None {
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
}
func (conn *Conn) onICEStateDisconnected() {
@@ -402,22 +423,22 @@ func (conn *Conn) onICEStateDisconnected() {
return
}
conn.log.Tracef("ICE connection state changed to disconnected")
conn.Log.Tracef("ICE connection state changed to disconnected")
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
// switch back to relay connection
if conn.isReadyToUpgrade() {
conn.log.Infof("ICE disconnected, set Relay to active connection")
conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
conn.wgWatcherWg.Add(1)
@@ -425,17 +446,18 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.currentConnPriority = connPriorityRelay
conn.currentConnPriority = conntype.Relay
} else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
conn.currentConnPriority = connPriorityNone
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
}
changed := conn.statusICE.Get() != StatusDisconnected
changed := conn.statusICE.Get() != worker.StatusDisconnected
if changed {
conn.guard.SetICEConnDisconnected()
}
conn.statusICE.Set(StatusDisconnected)
conn.statusICE.SetDisconnected()
peerState := State{
PubKey: conn.config.Key,
@@ -446,7 +468,7 @@ func (conn *Conn) onICEStateDisconnected() {
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil {
conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
}
}
@@ -456,41 +478,41 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
conn.Log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
conn.dumpState.RelayConnected()
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy, err := conn.newProxy(rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.dumpState.NewLocalProxy()
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.isICEActive() {
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected)
conn.statusRelay.SetConnected()
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
}
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
conn.Log.Warnf("Failed to close relay connection: %v", err)
}
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
@@ -502,12 +524,13 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()
conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
func (conn *Conn) onRelayDisconnected() {
@@ -518,14 +541,15 @@ func (conn *Conn) onRelayDisconnected() {
return
}
conn.log.Infof("relay connection is disconnected")
conn.Log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay {
conn.log.Infof("clean up WireGuard config")
if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.currentConnPriority = connPriorityNone
conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
}
if conn.wgProxyRelay != nil {
@@ -533,11 +557,11 @@ func (conn *Conn) onRelayDisconnected() {
conn.wgProxyRelay = nil
}
changed := conn.statusRelay.Get() != StatusDisconnected
changed := conn.statusRelay.Get() != worker.StatusDisconnected
if changed {
conn.guard.SetRelayedConnDisconnected()
}
conn.statusRelay.Set(StatusDisconnected)
conn.statusRelay.SetDisconnected()
peerState := State{
PubKey: conn.config.Key,
@@ -546,22 +570,15 @@ func (conn *Conn) onRelayDisconnected() {
ConnStatusUpdate: time.Now(),
}
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
conn.Log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
}
}
func (conn *Conn) listenGuardEvent(ctx context.Context) {
for {
select {
case <-conn.guard.Reconnect:
conn.log.Infof("send offer to peer")
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err)
}
case <-ctx.Done():
return
}
func (conn *Conn) onGuardEvent() {
conn.Log.Debugf("send offer to peer")
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil {
conn.Log.Errorf("failed to send offer: %v", err)
}
}
@@ -588,7 +605,7 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
if err != nil {
conn.log.Warnf("unable to save peer's Relay state, got error: %v", err)
conn.Log.Warnf("unable to save peer's Relay state, got error: %v", err)
}
}
@@ -607,17 +624,18 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
err := conn.statusRecorder.UpdatePeerICEState(peerState)
if err != nil {
conn.log.Warnf("unable to save peer's ICE state, got error: %v", err)
conn.Log.Warnf("unable to save peer's ICE state, got error: %v", err)
}
}
func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay.Set(StatusDisconnected)
conn.statusICE.Set(StatusDisconnected)
conn.statusRelay.SetDisconnected()
conn.statusICE.SetDisconnected()
conn.currentConnPriority = conntype.None
peerState := State{
PubKey: conn.config.Key,
ConnStatus: StatusDisconnected,
ConnStatus: StatusIdle,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
@@ -625,10 +643,10 @@ func (conn *Conn) setStatusToDisconnected() {
if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available.
// todo rethink status updates
conn.log.Debugf("error while updating peer's state, err: %v", err)
conn.Log.Debugf("error while updating peer's state, err: %v", err)
}
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
conn.Log.Debugf("failed to reset wireguard stats for peer: %s", err)
}
}
@@ -656,32 +674,24 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
}
func (conn *Conn) isRelayed() bool {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn:
return true
default:
return false
}
if conn.currentConnPriority == connPriorityICEP2P {
return false
}
return true
}
func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
if conn.statusRelay.Get() == worker.StatusConnected || conn.statusICE.Get() == worker.StatusConnected {
return StatusConnected
}
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
return StatusConnecting
}
return StatusDisconnected
return StatusConnecting
}
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
conn.mu.Lock()
defer conn.mu.Unlock()
// would be better to protect this with a mutex, but it could cause deadlock with Close function
defer func() {
if !connected {
@@ -689,12 +699,12 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
if conn.statusICE.Get() == StatusDisconnected {
if conn.statusICE.Get() == worker.StatusDisconnected {
return false
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() != StatusConnected {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false
}
}
@@ -716,7 +726,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err)
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDRelay = ""
@@ -725,7 +735,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err)
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDICE = ""
@@ -733,7 +743,7 @@ func (conn *Conn) freeUpConnID() {
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection")
conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
Port: conn.config.WgConfig.WgListenPort,
@@ -741,18 +751,18 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return nil, err
}
return wgProxy, nil
}
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
}
func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
}
func (conn *Conn) removeWgPeer() error {
@@ -760,10 +770,10 @@ func (conn *Conn) removeWgPeer() error {
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
conn.Log.Warnf("Failed to close wg proxy: %v", ierr)
}
}
if conn.wgProxyRelay != nil {
@@ -773,16 +783,16 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
func (conn *Conn) logTraceConnState() {
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
conn.Log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} else {
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
conn.Log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
}
}
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = proxy
@@ -793,6 +803,10 @@ func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr()
}
func (conn *Conn) AgentVersionString() string {
return conn.config.AgentVersion
}
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
if conn.config.RosenpassConfig.PubKey == nil {
return conn.config.WgConfig.PreSharedKey
@@ -804,7 +818,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey
}

View File

@@ -1,58 +1,29 @@
package peer
import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const (
// StatusConnected indicate the peer is in connected state
StatusConnected ConnStatus = iota
// StatusIdle indicate the peer is in disconnected state
StatusIdle ConnStatus = iota
// StatusConnecting indicate the peer is in connecting state
StatusConnecting
// StatusDisconnected indicate the peer is in disconnected state
StatusDisconnected
// StatusConnected indicate the peer is in connected state
StatusConnected
)
// ConnStatus describe the status of a peer's connection
type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string {
switch s {
case StatusConnecting:
return "Connecting"
case StatusConnected:
return "Connected"
case StatusDisconnected:
return "Disconnected"
case StatusIdle:
return "Idle"
default:
log.Errorf("unknown status: %d", s)
return "INVALID_PEER_CONNECTION_STATUS"

View File

@@ -14,7 +14,7 @@ func TestConnStatus_String(t *testing.T) {
want string
}{
{"StatusConnected", StatusConnected, "Connected"},
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
{"StatusIdle", StatusIdle, "Idle"},
{"StatusConnecting", StatusConnecting, "Connecting"},
}
@@ -24,5 +24,4 @@ func TestConnStatus_String(t *testing.T) {
assert.Equal(t, got, table.want, "they should be equal")
})
}
}

View File

@@ -1,7 +1,6 @@
package peer
import (
"context"
"fmt"
"os"
"sync"
@@ -11,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -18,6 +18,8 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
var testDispatcher = dispatcher.NewConnectionDispatcher()
var connConf = ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
@@ -48,7 +50,13 @@ func TestNewConn_interfaceFilter(t *testing.T) {
func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
sd := ServiceDependencies{
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil {
return
}
@@ -60,7 +68,13 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil {
return
}
@@ -94,7 +108,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil {
return
}
@@ -125,43 +145,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
tables := []struct {
name string
statusIce ConnStatus
statusRelay ConnStatus
want ConnStatus
}{
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
func TestConn_presharedKey(t *testing.T) {
conn1 := Conn{

View File

@@ -0,0 +1,29 @@
package conntype
import (
"fmt"
)
const (
None ConnPriority = 0
Relay ConnPriority = 1
ICETurn ConnPriority = 2
ICEP2P ConnPriority = 3
)
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case None:
return "None"
case Relay:
return "PriorityRelay"
case ICETurn:
return "PriorityICETurn"
case ICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}

View File

@@ -0,0 +1,52 @@
package dispatcher
import (
"sync"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type ConnectionListener struct {
OnConnected func(peerID id.ConnID)
OnDisconnected func(peerID id.ConnID)
}
type ConnectionDispatcher struct {
listeners map[*ConnectionListener]struct{}
mu sync.Mutex
}
func NewConnectionDispatcher() *ConnectionDispatcher {
return &ConnectionDispatcher{
listeners: make(map[*ConnectionListener]struct{}),
}
}
func (e *ConnectionDispatcher) AddListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
e.listeners[listener] = struct{}{}
}
func (e *ConnectionDispatcher) RemoveListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
delete(e.listeners, listener)
}
func (e *ConnectionDispatcher) NotifyConnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnConnected(peerConnID)
}
}
func (e *ConnectionDispatcher) NotifyDisconnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnDisconnected(peerConnID)
}
}

View File

@@ -8,10 +8,6 @@ import (
log "github.com/sirupsen/logrus"
)
const (
reconnectMaxElapsedTime = 30 * time.Minute
)
type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic.
@@ -25,7 +21,6 @@ type isConnectedFunc func() bool
type Guard struct {
Reconnect chan struct{}
log *log.Entry
isController bool
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
srWatcher *SRWatcher
@@ -33,11 +28,10 @@ type Guard struct {
iCEConnDisconnected chan struct{}
}
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
Reconnect: make(chan struct{}, 1),
log: log,
isController: isController,
isConnectedOnAllWay: isConnectedFn,
timeout: timeout,
srWatcher: srWatcher,
@@ -46,12 +40,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
}
}
func (g *Guard) Start(ctx context.Context) {
if g.isController {
g.reconnectLoopWithRetry(ctx)
} else {
g.listenForDisconnectEvents(ctx)
}
func (g *Guard) Start(ctx context.Context, eventCallback func()) {
g.reconnectLoopWithRetry(ctx, eventCallback)
}
func (g *Guard) SetRelayedConnDisconnected() {
@@ -68,9 +58,9 @@ func (g *Guard) SetICEConnDisconnected() {
}
}
// reconnectLoopWithRetry periodically check (max 30 min) the connection status.
// reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener()
@@ -93,7 +83,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
}
if !g.isConnectedOnAllWay() {
g.triggerOfferSending()
callback()
}
case <-g.relayedConnDisconnected:
@@ -121,39 +111,12 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
}
}
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
g.log.Infof("start listen for reconnect events...")
for {
select {
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending()
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending()
case <-srReconnectedChan:
g.triggerOfferSending()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
@@ -164,13 +127,6 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
return ticker
}
func (g *Guard) triggerOfferSending() {
select {
case g.Reconnect <- struct{}{}:
default:
}
}
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) {

View File

@@ -43,7 +43,6 @@ type OfferAnswer struct {
type Handshaker struct {
mu sync.Mutex
ctx context.Context
log *log.Entry
config ConnConfig
signaler *Signaler
@@ -57,9 +56,8 @@ type Handshaker struct {
remoteAnswerCh chan OfferAnswer
}
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
return &Handshaker{
ctx: ctx,
log: log,
config: config,
signaler: signaler,
@@ -74,10 +72,10 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
}
func (h *Handshaker) Listen() {
func (h *Handshaker) Listen(ctx context.Context) {
for {
h.log.Info("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx)
if err != nil {
var connectionClosedError *ConnectionClosedError
if errors.As(err, &connectionClosedError) {
@@ -127,7 +125,7 @@ func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
@@ -137,7 +135,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-h.ctx.Done():
case <-ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
}

View File

@@ -0,0 +1,5 @@
package id
import "unsafe"
type ConnID unsafe.Pointer

View File

@@ -15,7 +15,7 @@ import (
type WGIface interface {
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error)
GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy
Address() wgaddr.Address
}

View File

@@ -18,6 +18,8 @@ type notifier struct {
currentClientState bool
lastNotification int
lastNumberOfPeers int
lastFqdnAddress string
lastIPAddress string
}
func newNotifier() *notifier {
@@ -25,15 +27,22 @@ func newNotifier() *notifier {
}
func (n *notifier) setListener(listener Listener) {
n.serverStateLock.Lock()
lastNotification := n.lastNotification
numOfPeers := n.lastNumberOfPeers
fqdnAddress := n.lastFqdnAddress
address := n.lastIPAddress
n.serverStateLock.Unlock()
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
n.serverStateLock.Lock()
n.notifyListener(listener, n.lastNotification)
listener.OnPeersListChanged(n.lastNumberOfPeers)
n.serverStateLock.Unlock()
n.listener = listener
listener.OnAddressChanged(fqdnAddress, address)
notifyListener(listener, lastNotification)
// run on go routine to avoid on Java layer to call go functions on same thread
go listener.OnPeersListChanged(numOfPeers)
}
func (n *notifier) removeListener() {
@@ -44,41 +53,44 @@ func (n *notifier) removeListener() {
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
calculatedState := n.calculateState(mgmState, signalState)
if !n.isServerStateChanged(calculatedState) {
n.serverStateLock.Unlock()
return
}
n.lastNotification = calculatedState
n.serverStateLock.Unlock()
n.notify(n.lastNotification)
n.notify(calculatedState)
}
func (n *notifier) clientStart() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = true
n.lastNotification = stateConnecting
n.notify(n.lastNotification)
n.serverStateLock.Unlock()
n.notify(stateConnecting)
}
func (n *notifier) clientStop() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = false
n.lastNotification = stateDisconnected
n.notify(n.lastNotification)
n.serverStateLock.Unlock()
n.notify(stateDisconnected)
}
func (n *notifier) clientTearDown() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = false
n.lastNotification = stateDisconnecting
n.notify(n.lastNotification)
n.serverStateLock.Unlock()
n.notify(stateDisconnecting)
}
func (n *notifier) isServerStateChanged(newState int) bool {
@@ -87,26 +99,14 @@ func (n *notifier) isServerStateChanged(newState int) bool {
func (n *notifier) notify(state int) {
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
if n.listener == nil {
listener := n.listener
n.listenersLock.Unlock()
if listener == nil {
return
}
n.notifyListener(n.listener, state)
}
func (n *notifier) notifyListener(l Listener, state int) {
go func() {
switch state {
case stateDisconnected:
l.OnDisconnected()
case stateConnected:
l.OnConnected()
case stateConnecting:
l.OnConnecting()
case stateDisconnecting:
l.OnDisconnecting()
}
}()
notifyListener(listener, state)
}
func (n *notifier) calculateState(managementConn, signalConn bool) int {
@@ -126,20 +126,48 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
}
func (n *notifier) peerListChanged(numOfPeers int) {
n.serverStateLock.Lock()
n.lastNumberOfPeers = numOfPeers
n.serverStateLock.Unlock()
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
if n.listener == nil {
listener := n.listener
n.listenersLock.Unlock()
if listener == nil {
return
}
n.listener.OnPeersListChanged(numOfPeers)
// run on go routine to avoid on Java layer to call go functions on same thread
go listener.OnPeersListChanged(numOfPeers)
}
func (n *notifier) localAddressChanged(fqdn, address string) {
n.serverStateLock.Lock()
n.lastFqdnAddress = fqdn
n.lastIPAddress = address
n.serverStateLock.Unlock()
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
if n.listener == nil {
listener := n.listener
n.listenersLock.Unlock()
if listener == nil {
return
}
n.listener.OnAddressChanged(fqdn, address)
listener.OnAddressChanged(fqdn, address)
}
func notifyListener(l Listener, state int) {
switch state {
case stateDisconnected:
l.OnDisconnected()
case stateConnected:
l.OnConnected()
case stateConnecting:
l.OnConnecting()
case stateDisconnecting:
l.OnDisconnecting()
}
}

View File

@@ -135,14 +135,15 @@ type NSGroupState struct {
// FullStatus contains the full state held by the Status instance
type FullStatus struct {
Peers []State
ManagementState ManagementState
SignalState SignalState
LocalPeerState LocalPeerState
RosenpassState RosenpassState
Relays []relay.ProbeResult
NSGroupStates []NSGroupState
NumOfForwardingRules int
Peers []State
ManagementState ManagementState
SignalState SignalState
LocalPeerState LocalPeerState
RosenpassState RosenpassState
Relays []relay.ProbeResult
NSGroupStates []NSGroupState
NumOfForwardingRules int
LazyConnectionEnabled bool
}
// Status holds a state of peers, signal, management connections and relays
@@ -164,6 +165,7 @@ type Status struct {
rosenpassPermissive bool
nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
lazyConnectionEnabled bool
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -219,7 +221,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
}
// AddPeer adds peer to Daemon status map
func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error {
d.mux.Lock()
defer d.mux.Unlock()
@@ -229,7 +231,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
}
d.peers[peerPubKey] = State{
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
IP: ip,
ConnStatus: StatusIdle,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
@@ -511,9 +514,9 @@ func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
switch {
case receivedConnStatus == StatusConnecting:
return true
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting:
return true
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
return curr.IP != ""
default:
return false
@@ -689,6 +692,12 @@ func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
d.rosenpassEnabled = rosenpassEnabled
}
func (d *Status) UpdateLazyConnection(enabled bool) {
d.mux.Lock()
defer d.mux.Unlock()
d.lazyConnectionEnabled = enabled
}
// MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Lock()
@@ -761,6 +770,12 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
}
func (d *Status) GetLazyConnection() bool {
d.mux.Lock()
defer d.mux.Unlock()
return d.lazyConnectionEnabled
}
func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
@@ -872,12 +887,13 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
NumOfForwardingRules: len(d.ForwardingRules()),
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
NumOfForwardingRules: len(d.ForwardingRules()),
LazyConnectionEnabled: d.GetLazyConnection(),
}
d.mux.Lock()

View File

@@ -10,22 +10,24 @@ import (
func TestAddPeer(t *testing.T) {
key := "abc"
ip := "100.108.254.1"
status := NewRecorder("https://mgm")
err := status.AddPeer(key, "abc.netbird")
err := status.AddPeer(key, "abc.netbird", ip)
assert.NoError(t, err, "shouldn't return error")
_, exists := status.peers[key]
assert.True(t, exists, "value was found")
err = status.AddPeer(key, "abc.netbird")
err = status.AddPeer(key, "abc.netbird", ip)
assert.Error(t, err, "should return error on duplicate")
}
func TestGetPeer(t *testing.T) {
key := "abc"
ip := "100.108.254.1"
status := NewRecorder("https://mgm")
err := status.AddPeer(key, "abc.netbird")
err := status.AddPeer(key, "abc.netbird", ip)
assert.NoError(t, err, "shouldn't return error")
peerStatus, err := status.GetPeer(key)

View File

@@ -2,6 +2,7 @@ package peer
import (
"context"
"fmt"
"sync"
"time"
@@ -20,7 +21,7 @@ var (
)
type WGInterfaceStater interface {
GetStats(key string) (configurer.WGStats, error)
GetStats() (map[string]configurer.WGStats, error)
}
type WGWatcher struct {
@@ -146,9 +147,13 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
}
func (w *WGWatcher) wgState() (time.Time, error) {
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
wgStates, err := w.wgIfaceStater.GetStats()
if err != nil {
return time.Time{}, err
}
wgState, ok := wgStates[w.peerKey]
if !ok {
return time.Time{}, fmt.Errorf("peer %s not found in WireGuard endpoints", w.peerKey)
}
return wgState.LastHandshake, nil
}

View File

@@ -11,26 +11,11 @@ import (
)
type MocWgIface struct {
initial bool
lastHandshake time.Time
stop bool
stop bool
}
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
if !m.initial {
m.initial = true
return configurer.WGStats{}, nil
}
if !m.stop {
m.lastHandshake = time.Now()
}
stats := configurer.WGStats{
LastHandshake: m.lastHandshake,
}
return stats, nil
func (m *MocWgIface) GetStats() (map[string]configurer.WGStats, error) {
return map[string]configurer.WGStats{}, nil
}
func (m *MocWgIface) disconnect() {

View File

@@ -0,0 +1,55 @@
package worker
import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const (
StatusDisconnected Status = iota
StatusConnected
)
type Status int32
func (s Status) String() string {
switch s {
case StatusDisconnected:
return "Disconnected"
case StatusConnected:
return "Connected"
default:
log.Errorf("unknown status: %d", s)
return "unknown"
}
}
// AtomicWorkerStatus is a thread-safe wrapper for worker status
type AtomicWorkerStatus struct {
status atomic.Int32
}
func NewAtomicStatus() *AtomicWorkerStatus {
acs := &AtomicWorkerStatus{}
acs.SetDisconnected()
return acs
}
// Get returns the current connection status
func (acs *AtomicWorkerStatus) Get() Status {
return Status(acs.status.Load())
}
func (acs *AtomicWorkerStatus) SetConnected() {
acs.status.Store(int32(StatusConnected))
}
func (acs *AtomicWorkerStatus) SetDisconnected() {
acs.status.Store(int32(StatusDisconnected))
}
// String returns the string representation of the current status
func (acs *AtomicWorkerStatus) String() string {
return acs.Get().String()
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
@@ -397,10 +398,10 @@ func isRelayed(pair *ice.CandidatePair) bool {
return false
}
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
func selectedPriority(pair *ice.CandidatePair) conntype.ConnPriority {
if isRelayed(pair) {
return connPriorityICETurn
return conntype.ICETurn
} else {
return connPriorityICEP2P
return conntype.ICEP2P
}
}

View File

@@ -1,6 +1,7 @@
package peerstore
import (
"context"
"net/netip"
"sync"
@@ -79,6 +80,32 @@ func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
return p, true
}
func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
// this can be blocked because of the connect open limiter semaphore
if err := p.Open(ctx); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
func (s *Store) PeerConnClose(pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
p.Close()
}
func (s *Store) PeersPubKey() []string {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
@@ -41,6 +42,8 @@ type PKCEAuthProviderConfig struct {
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
@@ -100,6 +103,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
},
}

View File

@@ -3,7 +3,6 @@ package iface
import (
"net"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -18,5 +17,4 @@ type wgIfaceBase interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -32,6 +32,10 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() {
continue
}
nets = append(nets, r.Network.String())
}
sort.Strings(nets)

View File

@@ -1,5 +1,3 @@
//go:build !android
package routemanager
import (

View File

@@ -1,27 +0,0 @@
//go:build android
package routemanager
import (
"context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
}
func (r serverRouter) cleanUp() {
}
func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error {
return nil
}
func newServerRouter(context.Context, iface.WGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

File diff suppressed because it is too large Load Diff

View File

@@ -94,7 +94,7 @@ message LoginRequest {
bytes customDNSAddress = 7;
bool isLinuxDesktopClient = 8;
bool isUnixDesktopClient = 8;
string hostname = 9;
@@ -134,6 +134,7 @@ message LoginRequest {
// omits initialized empty slices due to omitempty tags
bool cleanDNSLabels = 27;
optional bool lazyConnectionEnabled = 28;
}
message LoginResponse {
@@ -274,6 +275,8 @@ message FullStatus {
int32 NumberOfForwardingRules = 8;
repeated SystemEvent events = 7;
bool lazyConnectionEnabled = 9;
}
// Networks

View File

@@ -139,6 +139,7 @@ func (s *Server) Start() error {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
if s.sessionWatcher == nil {
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
@@ -417,6 +418,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.latestConfigInput.DisableNotifications = msg.DisableNotifications
}
if msg.LazyConnectionEnabled != nil {
inputConfig.LazyConnectionEnabled = msg.LazyConnectionEnabled
s.latestConfigInput.LazyConnectionEnabled = msg.LazyConnectionEnabled
}
s.mutex.Unlock()
if msg.OptionalPreSharedKey != nil {
@@ -446,7 +452,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting)
if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsLinuxDesktopClient)
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
if err != nil {
state.Set(internal.StatusLoginFailed)
return nil, err
@@ -804,6 +810,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{

View File

@@ -97,6 +97,7 @@ type OutputOverview struct {
NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
Events []SystemEventOutput `json:"events" yaml:"events"`
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
}
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
@@ -136,6 +137,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
Events: mapEvents(pbFullStatus.GetEvents()),
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
}
if anon {
@@ -206,7 +208,7 @@ func mapPeers(
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
continue
}
if isPeerConnected {
@@ -384,6 +386,11 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
}
}
lazyConnectionEnabledStatus := "false"
if overview.LazyConnectionEnabled {
lazyConnectionEnabledStatus = "true"
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
@@ -405,6 +412,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"Networks: %s\n"+
"Forwarding rules: %d\n"+
"Peers count: %s\n",
@@ -419,6 +427,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
networks,
overview.NumberOfForwardingRules,
peersCountString,
@@ -533,23 +542,13 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
return peersString
}
func skipDetailByFilters(
peerState *proto.PeerState,
isConnected bool,
statusFilter string,
prefixNamesFilter []string,
prefixNamesFilterMap map[string]struct{},
ipsFilter map[string]struct{},
) bool {
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool {
statusEval := false
ipEval := false
nameEval := true
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
if lowerStatusFilter == "disconnected" && isConnected {
statusEval = true
} else if lowerStatusFilter == "connected" && !isConnected {
if !strings.EqualFold(peerStatus, statusFilter) {
statusEval = true
}
}

View File

@@ -383,7 +383,8 @@ func TestParsingToJSON(t *testing.T) {
"error": "timeout"
}
],
"events": []
"events": [],
"lazyConnectionEnabled": false
}`
// @formatter:on
@@ -484,6 +485,7 @@ dnsServers:
enabled: false
error: timeout
events: []
lazyConnectionEnabled: false
`
assert.Equal(t, expectedYAML, yaml)
@@ -548,6 +550,7 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Lazy connection: false
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
@@ -570,6 +573,7 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Lazy connection: false
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected

View File

@@ -62,6 +62,8 @@ func main() {
return
}
logFile = file
} else {
_ = util.InitLog("trace", "console")
}
// Create the Fyne application.
@@ -191,6 +193,7 @@ type serviceClient struct {
mAllowSSH *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
mNotifications *systray.MenuItem
mAdvancedSettings *systray.MenuItem
mCreateDebugBundle *systray.MenuItem
@@ -382,12 +385,12 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
s.adminURL = iAdminURL
loginRequest := proto.LoginRequest{
ManagementUrl: iMngURL,
AdminURL: iAdminURL,
IsLinuxDesktopClient: runtime.GOOS == "linux",
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
InterfaceName: &s.iInterfaceName.Text,
WireguardPort: &port,
ManagementUrl: iMngURL,
AdminURL: iAdminURL,
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
InterfaceName: &s.iInterfaceName.Text,
WireguardPort: &port,
}
if s.iPreSharedKey.Text != censoredPreSharedKey {
@@ -414,7 +417,7 @@ func (s *serviceClient) login() error {
}
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
IsLinuxDesktopClient: runtime.GOOS == "linux",
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
})
if err != nil {
log.Errorf("login to management URL with: %v", err)
@@ -630,6 +633,7 @@ func (s *serviceClient) onTrayReady() {
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable lazy connection", lazyConnMenuDescr, false)
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr)
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
@@ -689,104 +693,114 @@ func (s *serviceClient) onTrayReady() {
go s.eventManager.Start(s.ctx)
go func() {
for {
select {
case <-s.mUp.ClickedCh:
s.mUp.Disable()
go func() {
defer s.mUp.Enable()
err := s.menuUpClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mDown.ClickedCh:
s.mDown.Disable()
go func() {
defer s.mDown.Enable()
err := s.menuDownClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mAllowSSH.ClickedCh:
if s.mAllowSSH.Checked() {
s.mAllowSSH.Uncheck()
} else {
s.mAllowSSH.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAutoConnect.ClickedCh:
if s.mAutoConnect.Checked() {
s.mAutoConnect.Uncheck()
} else {
s.mAutoConnect.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mEnableRosenpass.ClickedCh:
if s.mEnableRosenpass.Checked() {
s.mEnableRosenpass.Uncheck()
} else {
s.mEnableRosenpass.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAdvancedSettings.ClickedCh:
s.mAdvancedSettings.Disable()
go func() {
defer s.mAdvancedSettings.Enable()
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
}()
case <-s.mCreateDebugBundle.ClickedCh:
s.mCreateDebugBundle.Disable()
go func() {
defer s.mCreateDebugBundle.Enable()
s.runSelfCommand("debug", "true")
}()
case <-s.mQuit.ClickedCh:
systray.Quit()
return
case <-s.mGitHub.ClickedCh:
err := openURL("https://github.com/netbirdio/netbird")
if err != nil {
log.Errorf("%s", err)
}
case <-s.mUpdate.ClickedCh:
err := openURL(version.DownloadUrl())
if err != nil {
log.Errorf("%s", err)
}
case <-s.mNetworks.ClickedCh:
s.mNetworks.Disable()
go func() {
defer s.mNetworks.Enable()
s.runSelfCommand("networks", "true")
}()
case <-s.mNotifications.ClickedCh:
if s.mNotifications.Checked() {
s.mNotifications.Uncheck()
} else {
s.mNotifications.Check()
}
if s.eventManager != nil {
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
}
go s.listenEvents()
}
func (s *serviceClient) listenEvents() {
for {
select {
case <-s.mUp.ClickedCh:
s.mUp.Disable()
go func() {
defer s.mUp.Enable()
err := s.menuUpClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mDown.ClickedCh:
s.mDown.Disable()
go func() {
defer s.mDown.Enable()
err := s.menuDownClick()
if err != nil {
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
return
}
}()
case <-s.mAllowSSH.ClickedCh:
if s.mAllowSSH.Checked() {
s.mAllowSSH.Uncheck()
} else {
s.mAllowSSH.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAutoConnect.ClickedCh:
if s.mAutoConnect.Checked() {
s.mAutoConnect.Uncheck()
} else {
s.mAutoConnect.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mEnableRosenpass.ClickedCh:
if s.mEnableRosenpass.Checked() {
s.mEnableRosenpass.Uncheck()
} else {
s.mEnableRosenpass.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mLazyConnEnabled.ClickedCh:
if s.mLazyConnEnabled.Checked() {
s.mLazyConnEnabled.Uncheck()
} else {
s.mLazyConnEnabled.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
case <-s.mAdvancedSettings.ClickedCh:
s.mAdvancedSettings.Disable()
go func() {
defer s.mAdvancedSettings.Enable()
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
}()
case <-s.mCreateDebugBundle.ClickedCh:
s.mCreateDebugBundle.Disable()
go func() {
defer s.mCreateDebugBundle.Enable()
s.runSelfCommand("debug", "true")
}()
case <-s.mQuit.ClickedCh:
systray.Quit()
return
case <-s.mGitHub.ClickedCh:
err := openURL("https://github.com/netbirdio/netbird")
if err != nil {
log.Errorf("%s", err)
}
case <-s.mUpdate.ClickedCh:
err := openURL(version.DownloadUrl())
if err != nil {
log.Errorf("%s", err)
}
case <-s.mNetworks.ClickedCh:
s.mNetworks.Disable()
go func() {
defer s.mNetworks.Enable()
s.runSelfCommand("networks", "true")
}()
case <-s.mNotifications.ClickedCh:
if s.mNotifications.Checked() {
s.mNotifications.Uncheck()
} else {
s.mNotifications.Check()
}
if s.eventManager != nil {
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
}
}
}()
}
}
func (s *serviceClient) runSelfCommand(command, arg string) {
@@ -1018,13 +1032,15 @@ func (s *serviceClient) updateConfig() error {
sshAllowed := s.mAllowSSH.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
notificationsDisabled := !s.mNotifications.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
loginRequest := proto.LoginRequest{
IsLinuxDesktopClient: runtime.GOOS == "linux",
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled,
DisableAutoConnect: &disableAutoStart,
DisableNotifications: &notificationsDisabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
}
if err := s.restartClient(&loginRequest); err != nil {

View File

@@ -5,6 +5,7 @@ const (
allowSSHMenuDescr = "Allow SSH connections"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connect"
notificationsMenuDescr = "Enable notifications"
advancedSettingsMenuDescr = "Advanced settings of the application"
debugBundleMenuDescr = "Create and open debug information bundle"

4
go.mod
View File

@@ -59,13 +59,12 @@ require (
github.com/hashicorp/go-version v1.6.0
github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1
github.com/mattn/go-sqlite3 v1.14.22
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -195,6 +194,7 @@ require (
github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
github.com/mholt/acmez/v2 v2.0.1 // indirect

4
go.sum
View File

@@ -507,8 +507,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=

View File

@@ -59,6 +59,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
# Dashboard
@@ -122,6 +123,7 @@ export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
export NETBIRD_AUTH_PKCE_LOGIN_FLAG
export NETBIRD_AUTH_PKCE_AUDIENCE
export NETBIRD_DASH_AUTH_USE_AUDIENCE
export NETBIRD_DASH_AUTH_AUDIENCE

View File

@@ -95,7 +95,8 @@
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
}
}
}

View File

@@ -28,3 +28,4 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH
NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
NETBIRD_RELAY_PORT=33445
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0

View File

@@ -0,0 +1,19 @@
package common
// LoginFlag introduces additional login flags to the PKCE authorization request
type LoginFlag uint8
const (
// LoginFlagPrompt adds prompt=login to the authorization request
LoginFlagPrompt LoginFlag = iota
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
LoginFlagMaxAge0
)
func (l LoginFlag) IsPromptLogin() bool {
return l == LoginFlagPrompt
}
func (l LoginFlag) IsMaxAge0Login() bool {
return l == LoginFlagMaxAge0
}

View File

@@ -260,8 +260,6 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
if err := msgHandler(decryptedResp); err != nil {
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
// hide any grpc error code that is not relevant for management
return fmt.Errorf("msg handler error: %v", err.Error())
}
}
}

View File

@@ -16,7 +16,7 @@ type AccountsAPI struct {
// List list all accounts, only returns one account always
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/accounts", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil)
if err != nil {
return nil, err
}
@@ -34,7 +34,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
// Delete delete account
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
if err != nil {
return err
}

View File

@@ -66,6 +66,15 @@ func TestAccounts_List_Err(t *testing.T) {
})
}
func TestAccounts_List_ConnErr(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
ret, err := c.Accounts.List(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "404")
assert.Empty(t, ret)
})
}
func TestAccounts_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -14,6 +15,7 @@ import (
type Client struct {
managementURL string
authHeader string
httpClient HttpClient
// Accounts NetBird account APIs
// see more: https://docs.netbird.io/api/resources/accounts
@@ -70,20 +72,29 @@ type Client struct {
// New initialize new Client instance using PAT token
func New(managementURL, token string) *Client {
client := &Client{
managementURL: managementURL,
authHeader: "Token " + token,
}
client.initialize()
return client
return NewWithOptions(
WithManagementURL(managementURL),
WithPAT(token),
)
}
// NewWithBearerToken initialize new Client instance using Bearer token type
func NewWithBearerToken(managementURL, token string) *Client {
return NewWithOptions(
WithManagementURL(managementURL),
WithBearerToken(token),
)
}
func NewWithOptions(opts ...option) *Client {
client := &Client{
managementURL: managementURL,
authHeader: "Bearer " + token,
httpClient: http.DefaultClient,
}
for _, option := range opts {
option(client)
}
client.initialize()
return client
}
@@ -104,7 +115,7 @@ func (c *Client) initialize() {
c.Events = &EventsAPI{c}
}
func (c *Client) newRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
if err != nil {
return nil, err
@@ -116,7 +127,7 @@ func (c *Client) newRequest(ctx context.Context, method, path string, body io.Re
req.Header.Add("Content-Type", "application/json")
}
resp, err := http.DefaultClient.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
@@ -124,7 +135,8 @@ func (c *Client) newRequest(ctx context.Context, method, path string, body io.Re
if resp.StatusCode > 299 {
parsedErr, pErr := parseResponse[util.ErrorResponse](resp)
if pErr != nil {
return nil, err
return nil, pErr
}
return nil, errors.New(parsedErr.Message)
}
@@ -135,13 +147,16 @@ func (c *Client) newRequest(ctx context.Context, method, path string, body io.Re
func parseResponse[T any](resp *http.Response) (T, error) {
var ret T
if resp.Body == nil {
return ret, errors.New("No body")
return ret, fmt.Errorf("Body missing, HTTP Error code %d", resp.StatusCode)
}
bs, err := io.ReadAll(resp.Body)
if err != nil {
return ret, err
}
err = json.Unmarshal(bs, &ret)
if err != nil {
return ret, fmt.Errorf("Error code %d, error unmarshalling body: %w", resp.StatusCode, err)
}
return ret, err
return ret, nil
}

View File

@@ -16,7 +16,7 @@ type DNSAPI struct {
// ListNameserverGroups list all nameserver groups
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou
// GetNameserverGroup get nameserver group info
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
// DeleteNameserverGroup delete nameserver group
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
if err != nil {
return err
}
@@ -94,7 +94,7 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st
// GetSettings get DNS settings
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/settings", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil)
if err != nil {
return nil, err
}
@@ -112,7 +112,7 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}

View File

@@ -14,7 +14,7 @@ type EventsAPI struct {
// List list all events
// See more: https://docs.netbird.io/api/resources/events#list-all-events
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/events", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil)
if err != nil {
return nil, err
}

View File

@@ -14,7 +14,7 @@ type GeoLocationAPI struct {
// ListCountries list all country codes
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil)
if err != nil {
return nil, err
}
@@ -28,7 +28,7 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro
// ListCountryCities Get a list of all English city names for a given country code
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
if err != nil {
return nil, err
}

View File

@@ -16,7 +16,7 @@ type GroupsAPI struct {
// List list all groups
// See more: https://docs.netbird.io/api/resources/groups#list-all-groups
func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/groups", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
// Get get group info
// See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group
func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/groups/"+groupID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
// Delete delete group
// See more: https://docs.netbird.io/api/resources/groups#delete-a-group
func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/groups/"+groupID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ type NetworksAPI struct {
// List list all networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-networks
func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/networks", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
// Get get network info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network
func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+networkID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
// Delete delete network
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network
func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/networks/"+networkID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil)
if err != nil {
return err
}
@@ -108,7 +108,7 @@ func (a *NetworksAPI) Resources(networkID string) *NetworkResourcesAPI {
// List list all resources in networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-resources
func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil)
if err != nil {
return nil, err
}
@@ -122,7 +122,7 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource,
// Get get network resource info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network-resource
func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) (*api.NetworkResource, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
if err != nil {
return nil, err
}
@@ -140,7 +140,7 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -158,7 +158,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -172,7 +172,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
// Delete delete network resource
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network-resource
func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
if err != nil {
return err
}
@@ -200,7 +200,7 @@ func (a *NetworksAPI) Routers(networkID string) *NetworkRoutersAPI {
// List list all routers in networks
// See more: https://docs.netbird.io/api/routers/networks#list-all-network-routers
func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil)
if err != nil {
return nil, err
}
@@ -214,7 +214,7 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro
// Get get network router info
// See more: https://docs.netbird.io/api/routers/networks#retrieve-a-network-router
func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*api.NetworkRouter, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
if err != nil {
return nil, err
}
@@ -232,7 +232,7 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -250,7 +250,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -264,7 +264,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
// Delete delete network router
// See more: https://docs.netbird.io/api/routers/networks#delete-a-network-router
func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
if err != nil {
return err
}

View File

@@ -0,0 +1,35 @@
package rest
import "net/http"
type option func(*Client)
type HttpClient interface {
Do(req *http.Request) (*http.Response, error)
}
func WithHttpClient(client HttpClient) option {
return func(c *Client) {
c.httpClient = client
}
}
func WithBearerToken(token string) option {
return WithAuthHeader("Bearer " + token)
}
func WithPAT(token string) option {
return WithAuthHeader("Token " + token)
}
func WithManagementURL(url string) option {
return func(c *Client) {
c.managementURL = url
}
}
func WithAuthHeader(value string) option {
return func(c *Client) {
c.authHeader = value
}
}

View File

@@ -16,7 +16,7 @@ type PeersAPI struct {
// List list all peers
// See more: https://docs.netbird.io/api/resources/peers#list-all-peers
func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/peers", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) {
// Get retrieve a peer
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer
func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/peers/"+peerID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
// Delete delete a peer
// See more: https://docs.netbird.io/api/resources/peers#delete-a-peer
func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/peers/"+peerID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil)
if err != nil {
return err
}
@@ -76,7 +76,7 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
// ListAccessiblePeers list all peers that the specified peer can connect to within the network
// See more: https://docs.netbird.io/api/resources/peers#list-accessible-peers
func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]api.Peer, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil)
if err != nil {
return nil, err
}

View File

@@ -16,7 +16,9 @@ type PoliciesAPI struct {
// List list all policies
// See more: https://docs.netbird.io/api/resources/policies#list-all-policies
func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/policies", nil)
path := "/api/policies"
resp, err := a.c.NewRequest(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
@@ -30,7 +32,7 @@ func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
// Get get policy info
// See more: https://docs.netbird.io/api/resources/policies#retrieve-a-policy
func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/policies/"+policyID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +50,7 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -62,11 +64,13 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO
// Update update policy info
// See more: https://docs.netbird.io/api/resources/policies#update-a-policy
func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.PutApiPoliciesPolicyIdJSONRequestBody) (*api.Policy, error) {
path := "/api/policies/" + policyID
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/policies/"+policyID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +84,7 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P
// Delete delete policy
// See more: https://docs.netbird.io/api/resources/policies#delete-a-policy
func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/policies/"+policyID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ type PostureChecksAPI struct {
// List list all posture checks
// See more: https://docs.netbird.io/api/resources/posture-checks#list-all-posture-checks
func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/posture-checks", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error)
// Get get posture check info
// See more: https://docs.netbird.io/api/resources/posture-checks#retrieve-a-posture-check
func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api.PostureCheck, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostur
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re
// Delete delete posture check
// See more: https://docs.netbird.io/api/resources/posture-checks#delete-a-posture-check
func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ type RoutesAPI struct {
// List list all routes
// See more: https://docs.netbird.io/api/resources/routes#list-all-routes
func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/routes", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
// Get get route info
// See more: https://docs.netbird.io/api/resources/routes#retrieve-a-route
func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/routes/"+routeID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes/"+routeID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONReq
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +80,7 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA
// Delete delete route
// See more: https://docs.netbird.io/api/resources/routes#delete-a-route
func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/routes/"+routeID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/routes/"+routeID, nil)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ type SetupKeysAPI struct {
// List list all setup keys
// See more: https://docs.netbird.io/api/resources/setup-keys#list-all-setup-keys
func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/setup-keys", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
// Get get setup key info
// See more: https://docs.netbird.io/api/resources/setup-keys#retrieve-a-setup-key
func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKey, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil)
if err != nil {
return nil, err
}
@@ -44,11 +44,13 @@ func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKe
// Create generate new Setup Key
// See more: https://docs.netbird.io/api/resources/setup-keys#create-a-setup-key
func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJSONRequestBody) (*api.SetupKeyClear, error) {
path := "/api/setup-keys"
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/setup-keys", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", path, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +68,7 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -80,7 +82,7 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap
// Delete delete setup key
// See more: https://docs.netbird.io/api/resources/setup-keys#delete-a-setup-key
func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ type TokensAPI struct {
// List list user tokens
// See more: https://docs.netbird.io/api/resources/tokens#list-all-tokens
func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAccessToken, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAcce
// Get get user token info
// See more: https://docs.netbird.io/api/resources/tokens#retrieve-a-token
func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.PersonalAccessToken, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil)
if err != nil {
return nil, err
}
@@ -48,7 +48,7 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA
// Delete delete user token
// See more: https://docs.netbird.io/api/resources/tokens#delete-a-token
func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ type UsersAPI struct {
// List list all users, only returns one user always
// See more: https://docs.netbird.io/api/resources/users#list-all-users
func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/users", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/users", nil)
if err != nil {
return nil, err
}
@@ -34,7 +34,7 @@ func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONReque
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -52,7 +52,7 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi
if err != nil {
return nil, err
}
resp, err := a.c.newRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes))
resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes))
if err != nil {
return nil, err
}
@@ -66,7 +66,7 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi
// Delete delete user
// See more: https://docs.netbird.io/api/resources/users#delete-a-user
func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
resp, err := a.c.newRequest(ctx, "DELETE", "/api/users/"+userID, nil)
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID, nil)
if err != nil {
return err
}
@@ -80,7 +80,7 @@ func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
// ResendInvitation resend user invitation
// See more: https://docs.netbird.io/api/resources/users#resend-user-invitation
func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
resp, err := a.c.newRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil)
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil)
if err != nil {
return err
}
@@ -94,7 +94,7 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
// Current gets the current user info
// See more: https://docs.netbird.io/api/resources/users#retrieve-current-user
func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
resp, err := a.c.newRequest(ctx, "GET", "/api/users/current", nil)
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/current", nil)
if err != nil {
return nil, err
}

View File

@@ -159,7 +159,7 @@ var (
if err != nil {
return err
}
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}

View File

@@ -8,6 +8,8 @@ import (
const maxDomains = 32
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 {
@@ -17,8 +19,6 @@ func ValidateDomains(domains []string) (List, error) {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}
domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
var domainList List
for _, d := range domains {
@@ -37,27 +37,20 @@ func ValidateDomains(domains []string) (List, error) {
return domainList, nil
}
// ValidateDomainsStrSlice checks if each domain in the list is valid
func ValidateDomainsStrSlice(domains []string) ([]string, error) {
// ValidateDomainsList checks if each domain in the list is valid
func ValidateDomainsList(domains []string) error {
if len(domains) == 0 {
return nil, nil
return nil
}
if len(domains) > maxDomains {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}
domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
var domainList []string
for _, d := range domains {
d := strings.ToLower(d)
if !domainRegex.MatchString(d) {
return domainList, fmt.Errorf("invalid domain format: %s", d)
return fmt.Errorf("invalid domain format: %s", d)
}
domainList = append(domainList, d)
}
return domainList, nil
return nil
}

View File

@@ -97,110 +97,89 @@ func TestValidateDomains(t *testing.T) {
}
}
// TestValidateDomainsStrSlice tests the ValidateDomainsStrSlice function.
func TestValidateDomainsStrSlice(t *testing.T) {
// Generate a slice of valid domains up to maxDomains
func TestValidateDomainsList(t *testing.T) {
validDomains := make([]string, maxDomains)
for i := 0; i < maxDomains; i++ {
for i := range maxDomains {
validDomains[i] = fmt.Sprintf("example%d.com", i)
}
tests := []struct {
name string
domains []string
expected []string
wantErr bool
name string
domains []string
wantErr bool
}{
{
name: "Empty list",
domains: nil,
expected: nil,
wantErr: false,
name: "Empty list",
domains: nil,
wantErr: false,
},
{
name: "Single valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
expected: []string{"sub.ex-ample.com"},
wantErr: false,
name: "Single valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Underscores in labels",
domains: []string{"_jabber._tcp.gmail.com"},
expected: []string{"_jabber._tcp.gmail.com"},
wantErr: false,
name: "Underscores in labels",
domains: []string{"_jabber._tcp.gmail.com"},
wantErr: false,
},
{
// Unlike ValidateDomains (which converts to punycode),
// ValidateDomainsStrSlice will fail on non-ASCII domain chars.
name: "Unicode domain fails (no punycode conversion)",
domains: []string{"münchen.de"},
expected: nil,
wantErr: true,
name: "Unicode domain fails (no punycode conversion)",
domains: []string{"münchen.de"},
wantErr: true,
},
{
name: "Invalid domain format - leading dash",
domains: []string{"-example.com"},
expected: nil,
wantErr: true,
name: "Invalid domain format - leading dash",
domains: []string{"-example.com"},
wantErr: true,
},
{
name: "Invalid domain format - trailing dash",
domains: []string{"example-.com"},
expected: nil,
wantErr: true,
name: "Invalid domain format - trailing dash",
domains: []string{"example-.com"},
wantErr: true,
},
{
// The function stops on the first invalid domain and returns an error,
// so only the first domain is definitely valid, but the second is invalid.
name: "Multiple domains with a valid one, then invalid",
domains: []string{"google.com", "invalid_domain.com-"},
expected: []string{"google.com"},
wantErr: true,
name: "Multiple domains with a valid one, then invalid",
domains: []string{"google.com", "invalid_domain.com-"},
wantErr: true,
},
{
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
expected: []string{"*.example.com"},
wantErr: false,
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
wantErr: false,
},
{
name: "Wildcard with leading dot - invalid",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
name: "Wildcard with leading dot - invalid",
domains: []string{".*.example.com"},
wantErr: true,
},
{
name: "Invalid wildcard with multiple asterisks",
domains: []string{"a.*.example.com"},
expected: nil,
wantErr: true,
name: "Invalid wildcard with multiple asterisks",
domains: []string{"a.*.example.com"},
wantErr: true,
},
{
name: "Exactly maxDomains items (valid)",
domains: validDomains,
expected: validDomains,
wantErr: false,
name: "Exactly maxDomains items (valid)",
domains: validDomains,
wantErr: false,
},
{
name: "Exceeds maxDomains items",
domains: append(validDomains, "extra.com"),
expected: nil,
wantErr: true,
name: "Exceeds maxDomains items",
domains: append(validDomains, "extra.com"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ValidateDomainsStrSlice(tt.domains)
// Check if we got an error where expected
err := ValidateDomainsList(tt.domains)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// Compare the returned domains to what we expect
assert.Equal(t, tt.expected, got)
})
}
}

Some files were not shown because too many files have changed in this diff Show More