mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-06 09:34:05 -04:00
Compare commits
38 Commits
feature/bu
...
v0.45.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa07b3b87b | ||
|
|
2bef214cc0 | ||
|
|
cfb2d82352 | ||
|
|
684501fd35 | ||
|
|
0492c1724a | ||
|
|
6f436e57b5 | ||
|
|
a0d28f9851 | ||
|
|
cdd27a9fe5 | ||
|
|
5523040acd | ||
|
|
670446d42e | ||
|
|
5bed6777d5 | ||
|
|
a0482ebc7b | ||
|
|
2a89d6e47a | ||
|
|
24f932b2ce | ||
|
|
c03435061c | ||
|
|
8e948739f1 | ||
|
|
9b53cad752 | ||
|
|
802a18167c | ||
|
|
e9108ffe6c | ||
|
|
e806d9de38 | ||
|
|
daa8380df9 | ||
|
|
4785f23fc4 | ||
|
|
1d4cfb83e7 | ||
|
|
207fa059d2 | ||
|
|
cbcdad7814 | ||
|
|
701c13807a | ||
|
|
99f8dc7748 | ||
|
|
f1de8e6eb0 | ||
|
|
b2a10780af | ||
|
|
43ae79d848 | ||
|
|
e520b64c6d | ||
|
|
92c91bbdd8 | ||
|
|
adf494e1ac | ||
|
|
2158461121 | ||
|
|
0cd4b601c3 | ||
|
|
ee1cec47b3 | ||
|
|
efb0edfc4c | ||
|
|
20f59ddecb |
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -37,16 +37,21 @@ If yes, which one?
|
||||
|
||||
**Debug output**
|
||||
|
||||
To help us resolve the problem, please attach the following debug output
|
||||
To help us resolve the problem, please attach the following anonymized status output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
As well as the file created by
|
||||
Create and upload a debug bundle, and share the returned file key:
|
||||
|
||||
netbird debug for 1m -AS -U
|
||||
|
||||
*Uploaded files are automatically deleted after 30 days.*
|
||||
|
||||
|
||||
Alternatively, create the file only and attach it here manually:
|
||||
|
||||
netbird debug for 1m -AS
|
||||
|
||||
|
||||
We advise reviewing the anonymized output for any remaining personal information.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -13,3 +13,5 @@
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] Extended the README / documentation, if necessary
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
8
.github/workflows/golang-test-linux.yml
vendored
8
.github/workflows/golang-test-linux.yml
vendored
@@ -223,6 +223,10 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -269,6 +273,10 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
|
||||
@@ -179,6 +179,7 @@ jobs:
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -98,11 +99,11 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
}
|
||||
}
|
||||
|
||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isLinuxRunningDesktop() bool {
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
@@ -26,22 +26,23 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -80,6 +81,7 @@ var (
|
||||
blockLANAccess bool
|
||||
debugUploadBundle bool
|
||||
debugUploadBundleURL string
|
||||
lazyConnEnabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -184,6 +186,7 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
|
||||
@@ -44,7 +44,7 @@ func init() {
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
case "", "idle", "connecting", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||
}
|
||||
|
||||
if len(ipsFilter) > 0 {
|
||||
|
||||
@@ -194,6 +194,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -262,17 +266,17 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -332,6 +336,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
@@ -201,14 +201,30 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
func (c *KernelConfigurer) Close() {
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
||||
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
||||
return nil, fmt.Errorf("wgctl: %w", err)
|
||||
}
|
||||
return WGStats{
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
TxBytes: peer.TransmitBytes,
|
||||
RxBytes: peer.ReceiveBytes,
|
||||
}, nil
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
log.Errorf("Got error while closing wgctl: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||
}
|
||||
|
||||
for _, peer := range wgDevice.Peers {
|
||||
stats[peer.PublicKey.String()] = WGStats{
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
TxBytes: peer.TransmitBytes,
|
||||
RxBytes: peer.ReceiveBytes,
|
||||
}
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -17,6 +18,13 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
)
|
||||
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
|
||||
type WGUSPConfigurer struct {
|
||||
@@ -217,91 +225,75 @@ func (t *WGUSPConfigurer) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
ipc, err := t.device.IpcGet()
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
||||
return nil, fmt.Errorf("ipc get: %w", err)
|
||||
}
|
||||
|
||||
stats, err := findPeerInfo(ipc, peerKey, []string{
|
||||
"last_handshake_time_sec",
|
||||
"last_handshake_time_nsec",
|
||||
"tx_bytes",
|
||||
"rx_bytes",
|
||||
})
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
||||
}
|
||||
|
||||
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||
}
|
||||
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
||||
}
|
||||
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
||||
}
|
||||
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
||||
}
|
||||
|
||||
return WGStats{
|
||||
LastHandshake: time.Unix(sec, nsec),
|
||||
TxBytes: txBytes,
|
||||
RxBytes: rxBytes,
|
||||
}, nil
|
||||
return parseTransfers(ipc)
|
||||
}
|
||||
|
||||
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
|
||||
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
||||
|
||||
lines := strings.Split(ipcInput, "\n")
|
||||
|
||||
configFound := map[string]string{}
|
||||
foundPeer := false
|
||||
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
var (
|
||||
currentKey string
|
||||
currentStats WGStats
|
||||
hasPeer bool
|
||||
)
|
||||
lines := strings.Split(ipc, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// If we're within the details of the found peer and encounter another public key,
|
||||
// this means we're starting another peer's details. So, stop.
|
||||
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
||||
break
|
||||
}
|
||||
|
||||
// Identify the peer with the specific public key
|
||||
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
||||
foundPeer = true
|
||||
}
|
||||
|
||||
for _, key := range searchConfigKeys {
|
||||
if foundPeer && strings.HasPrefix(line, key+"=") {
|
||||
v := strings.SplitN(line, "=", 2)
|
||||
configFound[v[0]] = v[1]
|
||||
if strings.HasPrefix(line, "public_key=") {
|
||||
peerID := strings.TrimPrefix(line, "public_key=")
|
||||
h, err := hex.DecodeString(peerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode peerID: %w", err)
|
||||
}
|
||||
currentKey = base64.StdEncoding.EncodeToString(h)
|
||||
currentStats = WGStats{} // Reset stats for the new peer
|
||||
hasPeer = true
|
||||
stats[currentKey] = currentStats
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasPeer {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.SplitN(line, "=", 2)
|
||||
if len(key) != 2 {
|
||||
continue
|
||||
}
|
||||
switch key[0] {
|
||||
case ipcKeyLastHandshakeTimeSec:
|
||||
hs, err := toLastHandshake(key[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentStats.LastHandshake = hs
|
||||
stats[currentKey] = currentStats
|
||||
case ipcKeyRxBytes:
|
||||
rxBytes, err := toBytes(key[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
||||
}
|
||||
currentStats.RxBytes = rxBytes
|
||||
stats[currentKey] = currentStats
|
||||
case ipcKeyTxBytes:
|
||||
TxBytes, err := toBytes(key[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
||||
}
|
||||
currentStats.TxBytes = TxBytes
|
||||
stats[currentKey] = currentStats
|
||||
}
|
||||
}
|
||||
|
||||
// todo: use multierr
|
||||
for _, key := range searchConfigKeys {
|
||||
if _, ok := configFound[key]; !ok {
|
||||
return configFound, fmt.Errorf("config key not found: %s", key)
|
||||
}
|
||||
}
|
||||
if !foundPeer {
|
||||
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
||||
}
|
||||
|
||||
return configFound, nil
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
@@ -355,6 +347,18 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func toLastHandshake(stringVar string) (time.Time, error) {
|
||||
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||
}
|
||||
return time.Unix(sec, 0), nil
|
||||
}
|
||||
|
||||
func toBytes(s string) (int64, error) {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.ControlPlaneMark
|
||||
|
||||
@@ -2,10 +2,8 @@ package configurer
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -34,58 +32,35 @@ errno=0
|
||||
|
||||
`
|
||||
|
||||
func Test_findPeerInfo(t *testing.T) {
|
||||
func Test_parseTransfers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerKey string
|
||||
searchKeys []string
|
||||
want map[string]string
|
||||
wantErr bool
|
||||
name string
|
||||
peerKey string
|
||||
want WGStats
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
searchKeys: []string{"tx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "38333",
|
||||
name: "single",
|
||||
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
||||
want: WGStats{
|
||||
TxBytes: 0,
|
||||
RxBytes: 0,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "38333",
|
||||
"rx_bytes": "2224",
|
||||
name: "multiple",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
want: WGStats{
|
||||
TxBytes: 38333,
|
||||
RxBytes: 2224,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "lastpeer",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "1212111",
|
||||
"rx_bytes": "1929999999",
|
||||
name: "lastpeer",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
want: WGStats{
|
||||
TxBytes: 1212111,
|
||||
RxBytes: 1929999999,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "peer not found",
|
||||
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
||||
searchKeys: nil,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key not found",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
searchKeys: []string{"tx_bytes", "unknown_key"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "1212111",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
|
||||
key, err := wgtypes.NewKey(res)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
||||
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
||||
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
||||
stats, err := parseTransfers(ipcFixture)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
stat, ok := stats[key.String()]
|
||||
if !ok {
|
||||
require.True(t, ok)
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.want, stat)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,5 +16,5 @@ type WGConfigurer interface {
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close()
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
}
|
||||
|
||||
@@ -212,9 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||
return w.tun.Device()
|
||||
}
|
||||
|
||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return w.configurer.GetStats(peerKey)
|
||||
// GetStats returns the last handshake time, rx and tx bytes
|
||||
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
return w.configurer.GetStats()
|
||||
}
|
||||
|
||||
func (w *WGIface) waitUntilRemoved() error {
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
|
||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||
|
||||
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
|
||||
|
||||
Unicode True
|
||||
|
||||
######################################################################
|
||||
@@ -49,6 +51,10 @@ ShowInstDetails Show
|
||||
|
||||
######################################################################
|
||||
|
||||
!include "MUI2.nsh"
|
||||
!include LogicLib.nsh
|
||||
!include "nsDialogs.nsh"
|
||||
|
||||
!define MUI_ICON "${ICON}"
|
||||
!define MUI_UNICON "${ICON}"
|
||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||
@@ -58,9 +64,6 @@ ShowInstDetails Show
|
||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||
######################################################################
|
||||
|
||||
!include "MUI2.nsh"
|
||||
!include LogicLib.nsh
|
||||
|
||||
!define MUI_ABORTWARNING
|
||||
!define MUI_UNABORTWARNING
|
||||
|
||||
@@ -70,13 +73,16 @@ ShowInstDetails Show
|
||||
|
||||
!insertmacro MUI_PAGE_DIRECTORY
|
||||
|
||||
; Custom page for autostart checkbox
|
||||
Page custom AutostartPage AutostartPageLeave
|
||||
|
||||
!insertmacro MUI_PAGE_INSTFILES
|
||||
|
||||
!insertmacro MUI_PAGE_FINISH
|
||||
|
||||
!insertmacro MUI_UNPAGE_WELCOME
|
||||
|
||||
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
||||
|
||||
!insertmacro MUI_UNPAGE_CONFIRM
|
||||
|
||||
!insertmacro MUI_UNPAGE_INSTFILES
|
||||
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
|
||||
Var AutostartCheckbox
|
||||
Var AutostartEnabled
|
||||
|
||||
; Variables for uninstall data deletion option
|
||||
Var DeleteDataCheckbox
|
||||
Var DeleteDataEnabled
|
||||
|
||||
######################################################################
|
||||
|
||||
; Function to create the autostart options page
|
||||
@@ -104,8 +114,8 @@ Function AutostartPage
|
||||
|
||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||
Pop $AutostartCheckbox
|
||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||
${NSD_Check} $AutostartCheckbox
|
||||
StrCpy $AutostartEnabled "1"
|
||||
|
||||
nsDialogs::Show
|
||||
FunctionEnd
|
||||
@@ -115,6 +125,30 @@ Function AutostartPageLeave
|
||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||
FunctionEnd
|
||||
|
||||
; Function to create the uninstall data deletion page
|
||||
Function un.DeleteDataPage
|
||||
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
|
||||
|
||||
nsDialogs::Create 1018
|
||||
Pop $0
|
||||
|
||||
${If} $0 == error
|
||||
Abort
|
||||
${EndIf}
|
||||
|
||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
|
||||
Pop $DeleteDataCheckbox
|
||||
${NSD_Uncheck} $DeleteDataCheckbox
|
||||
StrCpy $DeleteDataEnabled "0"
|
||||
|
||||
nsDialogs::Show
|
||||
FunctionEnd
|
||||
|
||||
; Function to handle leaving the data deletion page
|
||||
Function un.DeleteDataPageLeave
|
||||
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
|
||||
FunctionEnd
|
||||
|
||||
Function GetAppFromCommand
|
||||
Exch $1
|
||||
Push $2
|
||||
@@ -176,10 +210,10 @@ ${EndIf}
|
||||
FunctionEnd
|
||||
######################################################################
|
||||
Section -MainProgram
|
||||
${INSTALL_TYPE}
|
||||
# SetOverwrite ifnewer
|
||||
SetOutPath "$INSTDIR"
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
${INSTALL_TYPE}
|
||||
# SetOverwrite ifnewer
|
||||
SetOutPath "$INSTDIR"
|
||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||
SectionEnd
|
||||
######################################################################
|
||||
|
||||
@@ -225,31 +259,58 @@ SectionEnd
|
||||
Section Uninstall
|
||||
${INSTALL_TYPE}
|
||||
|
||||
DetailPrint "Stopping Netbird service..."
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||
DetailPrint "Uninstalling Netbird service..."
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
|
||||
# kill ui client
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
${If} $DeleteDataEnabled == "1"
|
||||
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
|
||||
ClearErrors
|
||||
RMDir /r "${NETBIRD_DATA_DIR}"
|
||||
IfErrors 0 +2 ; If no errors, jump over the message
|
||||
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
|
||||
DetailPrint "Netbird data directory removal complete."
|
||||
${Else}
|
||||
DetailPrint "User did not opt to delete Netbird data."
|
||||
${EndIf}
|
||||
|
||||
# wait the service uninstall take unblock the executable
|
||||
DetailPrint "Waiting for service handle to be released..."
|
||||
Sleep 3000
|
||||
|
||||
DetailPrint "Deleting application files..."
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
DetailPrint "Removing application directory..."
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
DetailPrint "Removing shortcuts..."
|
||||
SetShellVarContext all
|
||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||
|
||||
DetailPrint "Removing registry keys..."
|
||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||
|
||||
DetailPrint "Removing application directory from PATH..."
|
||||
EnVar::SetHKLM
|
||||
EnVar::DeleteValue "path" "$INSTDIR"
|
||||
|
||||
DetailPrint "Uninstallation finished."
|
||||
SectionEnd
|
||||
|
||||
|
||||
|
||||
@@ -76,12 +76,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
|
||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||
log.Errorf("failed to set legacy management flag: %v", err)
|
||||
}
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
|
||||
@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
||||
if runtime.GOOS == "freebsd" {
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
|
||||
@@ -101,7 +101,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
}
|
||||
if !p.providerConfig.DisablePromptLogin {
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
}
|
||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
@@ -7,15 +7,36 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/management/client/common"
|
||||
)
|
||||
|
||||
func TestPromptLogin(t *testing.T) {
|
||||
const (
|
||||
promptLogin = "prompt=login"
|
||||
maxAge0 = "max_age=0"
|
||||
)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
prompt bool
|
||||
name string
|
||||
loginFlag mgm.LoginFlag
|
||||
disablePromptLogin bool
|
||||
expect string
|
||||
}{
|
||||
{"PromptLogin", true},
|
||||
{"NoPromptLogin", false},
|
||||
{
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
expect: promptLogin,
|
||||
},
|
||||
{
|
||||
name: "Max age 0 login",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expect: maxAge0,
|
||||
},
|
||||
{
|
||||
name: "Disable prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
disablePromptLogin: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
@@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||
UseIDToken: true,
|
||||
DisablePromptLogin: !tc.prompt,
|
||||
LoginFlag: tc.loginFlag,
|
||||
}
|
||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||
if err != nil {
|
||||
@@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to request auth info: %v", err)
|
||||
}
|
||||
pattern := "prompt=login"
|
||||
if tc.prompt {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
||||
|
||||
if !tc.disablePromptLogin {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||
} else {
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -74,6 +74,8 @@ type ConfigInput struct {
|
||||
DisableNotifications *bool
|
||||
|
||||
DNSLabels domain.List
|
||||
|
||||
LazyConnectionEnabled *bool
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -138,6 +140,8 @@ type Config struct {
|
||||
ClientCertKeyPath string
|
||||
|
||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@@ -524,6 +528,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||
updated = true
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
303
client/internal/conn_mgr.go
Normal file
303
client/internal/conn_mgr.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
)
|
||||
|
||||
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||
//
|
||||
// The connection manager is responsible for:
|
||||
// - Managing lazy connections via the lazyConnManager
|
||||
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||
// - Handling connection establishment based on peer signaling
|
||||
//
|
||||
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||
type ConnMgr struct {
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
dispatcher *dispatcher.ConnectionDispatcher
|
||||
enabledLocally bool
|
||||
|
||||
lazyConnMgr *manager.Manager
|
||||
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
||||
e := &ConnMgr{
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
dispatcher: dispatcher,
|
||||
}
|
||||
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||
e.enabledLocally = true
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||
func (e *ConnMgr) Start(ctx context.Context) {
|
||||
if e.lazyConnMgr != nil {
|
||||
log.Errorf("lazy connection manager is already started")
|
||||
return
|
||||
}
|
||||
|
||||
if !e.enabledLocally {
|
||||
log.Infof("lazy connection manager is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
}
|
||||
|
||||
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
|
||||
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||
// do not disable lazy connection manager if it was enabled by env var
|
||||
if e.enabledLocally {
|
||||
return nil
|
||||
}
|
||||
|
||||
if enabled {
|
||||
// if the lazy connection manager is already started, do not start it again
|
||||
if e.lazyConnMgr != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("lazy connection manager is enabled by management feature flag")
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
return e.addPeersToLazyConnManager(ctx)
|
||||
} else {
|
||||
if e.lazyConnMgr == nil {
|
||||
return nil
|
||||
}
|
||||
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||
e.closeManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
|
||||
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
|
||||
if e.lazyConnMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
|
||||
|
||||
for peerID := range peerIDs {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||
continue
|
||||
}
|
||||
|
||||
lazyPeerCfg := lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||
PeerConnID: peerConn.ConnID(),
|
||||
Log: peerConn.Log,
|
||||
}
|
||||
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||
}
|
||||
|
||||
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
|
||||
for _, peerID := range added {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
|
||||
continue
|
||||
}
|
||||
|
||||
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
|
||||
if err := peerConn.Open(e.ctx); err != nil {
|
||||
peerConn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
|
||||
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
|
||||
return true
|
||||
}
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !lazyconn.IsSupported(conn.AgentVersionString()) {
|
||||
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
lazyPeerCfg := lazyconn.PeerConfig{
|
||||
PublicKey: peerKey,
|
||||
AllowedIPs: conn.WgConfig().AllowedIps,
|
||||
PeerConnID: conn.ConnID(),
|
||||
Log: conn.Log,
|
||||
}
|
||||
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if excluded {
|
||||
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Infof("peer added to lazy conn manager")
|
||||
return
|
||||
}
|
||||
|
||||
func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||
conn, ok := e.peerStore.Remove(peerKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
}
|
||||
|
||||
e.lazyConnMgr.RemovePeer(peerKey)
|
||||
conn.Log.Infof("removed peer from lazy conn manager")
|
||||
}
|
||||
|
||||
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
||||
conn, ok := e.peerStore.PeerConn(peerKey)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return conn, true
|
||||
}
|
||||
|
||||
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
|
||||
conn.Log.Infof("activated peer from inactive state")
|
||||
if err := conn.Open(e.ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
}
|
||||
return conn, true
|
||||
}
|
||||
|
||||
func (e *ConnMgr) Close() {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
}
|
||||
|
||||
e.ctxCancel()
|
||||
e.wg.Wait()
|
||||
e.lazyConnMgr = nil
|
||||
}
|
||||
|
||||
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
|
||||
cfg := manager.Config{
|
||||
InactivityThreshold: inactivityThresholdEnv(),
|
||||
}
|
||||
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
|
||||
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
e.ctx = ctx
|
||||
e.ctxCancel = cancel
|
||||
|
||||
e.wg.Add(1)
|
||||
go func() {
|
||||
defer e.wg.Done()
|
||||
e.lazyConnMgr.Start(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
|
||||
peers := e.peerStore.PeersPubKey()
|
||||
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
|
||||
for _, peerID := range peers {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||
continue
|
||||
}
|
||||
|
||||
lazyPeerCfg := lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||
PeerConnID: peerConn.ConnID(),
|
||||
Log: peerConn.Log,
|
||||
}
|
||||
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||
}
|
||||
|
||||
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
|
||||
}
|
||||
|
||||
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||
if e.lazyConnMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.ctxCancel()
|
||||
e.wg.Wait()
|
||||
e.lazyConnMgr = nil
|
||||
|
||||
for _, peerID := range e.peerStore.PeersPubKey() {
|
||||
e.peerStore.PeerConnOpen(ctx, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||
return e.lazyConnMgr != nil && e.ctxCancel != nil
|
||||
}
|
||||
|
||||
func inactivityThresholdEnv() *time.Duration {
|
||||
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||
if envValue == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedMinutes, err := strconv.Atoi(envValue)
|
||||
if err != nil || parsedMinutes <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
d := time.Duration(parsedMinutes) * time.Minute
|
||||
return &d
|
||||
}
|
||||
@@ -440,7 +440,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
DisableDNS: config.DisableDNS,
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
BlockLANAccess: config.BlockLANAccess,
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -481,7 +482,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
return signalClient, nil
|
||||
}
|
||||
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -376,6 +377,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addProf() (err error) {
|
||||
@@ -533,6 +535,33 @@ func (g *BundleGenerator) addLogfile() error {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
// add latest rotated log file
|
||||
pattern := filepath.Join(logDir, "client-*.log.gz")
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("failed to glob rotated logs: %v", err)
|
||||
} else if len(files) > 0 {
|
||||
// pick the file with the latest ModTime
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
fi, err := os.Stat(files[i])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
|
||||
return false
|
||||
}
|
||||
fj, err := os.Stat(files[j])
|
||||
if err != nil {
|
||||
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
|
||||
return false
|
||||
}
|
||||
return fi.ModTime().Before(fj.ModTime())
|
||||
})
|
||||
latest := files[len(files)-1]
|
||||
name := filepath.Base(latest)
|
||||
if err := g.addSingleLogFileGz(latest, name); err != nil {
|
||||
log.Warnf("failed to add rotated log %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
stdErrLogPath := filepath.Join(logDir, errorLogFile)
|
||||
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
|
||||
if runtime.GOOS == "darwin" {
|
||||
@@ -563,16 +592,13 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
var logReader io.Reader
|
||||
var logReader io.Reader = logFile
|
||||
if g.anonymize {
|
||||
var writer *io.PipeWriter
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go anonymizeLog(logFile, writer, g.anonymizer)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(logReader, targetName); err != nil {
|
||||
return fmt.Errorf("add %s to zip: %w", targetName, err)
|
||||
}
|
||||
@@ -580,6 +606,44 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addSingleLogFileGz adds a single gzipped log file to the archive
|
||||
func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
|
||||
f, err := os.Open(logPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open gz log file %s: %w", targetName, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gzr, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create gzip reader: %w", err)
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
var logReader io.Reader = gzr
|
||||
if g.anonymize {
|
||||
var pw *io.PipeWriter
|
||||
logReader, pw = io.Pipe()
|
||||
go anonymizeLog(gzr, pw, g.anonymizer)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
if _, err := io.Copy(gw, logReader); err != nil {
|
||||
return fmt.Errorf("re-gzip: %w", err)
|
||||
}
|
||||
|
||||
if err := gw.Close(); err != nil {
|
||||
return fmt.Errorf("close gzip writer: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(&buf, targetName); err != nil {
|
||||
return fmt.Errorf("add anonymized gz: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
|
||||
header := &zip.FileHeader{
|
||||
Name: filename,
|
||||
|
||||
@@ -30,9 +30,12 @@ const (
|
||||
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
|
||||
dnsSecDisabled = "no"
|
||||
)
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
@@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
Family: unix.AF_INET,
|
||||
Address: ipAs4[:],
|
||||
}
|
||||
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
|
||||
if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
|
||||
return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
|
||||
}
|
||||
|
||||
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
|
||||
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
|
||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -5,7 +5,6 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -18,5 +17,4 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -13,6 +12,5 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
@@ -122,6 +123,8 @@ type EngineConfig struct {
|
||||
DisableFirewall bool
|
||||
|
||||
BlockLANAccess bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -134,6 +137,8 @@ type Engine struct {
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerStore *peerstore.Store
|
||||
|
||||
connMgr *ConnMgr
|
||||
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
|
||||
@@ -170,7 +175,8 @@ type Engine struct {
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
|
||||
statusRecorder *peer.Status
|
||||
statusRecorder *peer.Status
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
|
||||
firewall firewallManager.Manager
|
||||
routeManager routemanager.Manager
|
||||
@@ -235,6 +241,8 @@ func NewEngine(
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
}
|
||||
|
||||
path := statemanager.GetDefaultStatePath()
|
||||
if runtime.GOOS == "ios" {
|
||||
if !fileExists(mobileDep.StateFilePath) {
|
||||
err := createFile(mobileDep.StateFilePath)
|
||||
@@ -244,11 +252,9 @@ func NewEngine(
|
||||
}
|
||||
}
|
||||
|
||||
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
|
||||
}
|
||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||
engine.stateManager = statemanager.New(path)
|
||||
path = mobileDep.StateFilePath
|
||||
}
|
||||
engine.stateManager = statemanager.New(path)
|
||||
|
||||
return engine
|
||||
}
|
||||
@@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
}
|
||||
|
||||
// stopping network monitor first to avoid starting the engine again
|
||||
if e.networkMonitor != nil {
|
||||
e.networkMonitor.Stop()
|
||||
@@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
|
||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
if err := e.removeAllPeers(); err != nil {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
}
|
||||
|
||||
@@ -405,8 +414,7 @@ func (e *Engine) Start() error {
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
err = e.wgInterfaceCreate()
|
||||
if err != nil {
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
@@ -442,6 +450,11 @@ func (e *Engine) Start() error {
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
|
||||
@@ -450,7 +463,6 @@ func (e *Engine) Start() error {
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -550,6 +562,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
var modified []*mgmProto.RemotePeerConfig
|
||||
for _, p := range peersUpdate {
|
||||
peerPubKey := p.GetWgPubKey()
|
||||
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPeer.AgentVersionString() != p.AgentVersion {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
|
||||
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
|
||||
if !ok {
|
||||
continue
|
||||
@@ -559,8 +581,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
|
||||
if err != nil {
|
||||
if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
|
||||
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
||||
}
|
||||
}
|
||||
@@ -621,16 +642,11 @@ func (e *Engine) removePeer(peerKey string) error {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
if err != nil {
|
||||
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
||||
}
|
||||
}()
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
conn, exists := e.peerStore.Remove(peerKey)
|
||||
if exists {
|
||||
conn.Close()
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
if err != nil {
|
||||
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -952,12 +968,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||
log.Errorf("failed to update local IPs: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
|
||||
// This needs to be toggled before applying routes.
|
||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||
log.Errorf("failed to set legacy management flag: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
@@ -976,7 +1004,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Ingress forward rules
|
||||
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
if err != nil {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
}
|
||||
|
||||
@@ -1022,6 +1051,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
|
||||
e.connMgr.SetExcludeList(excludedLazyPeers)
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
@@ -1155,7 +1188,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
||||
PubKey: offlinePeer.GetWgPubKey(),
|
||||
FQDN: offlinePeer.GetFqdn(),
|
||||
ConnStatus: peer.StatusDisconnected,
|
||||
ConnStatus: peer.StatusIdle,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
@@ -1191,12 +1224,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
peerIPs = append(peerIPs, allowedNetIP)
|
||||
}
|
||||
|
||||
conn, err := e.createPeerConn(peerKey, peerIPs)
|
||||
conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
|
||||
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
||||
conn.Close()
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
@@ -1205,17 +1243,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
conn.Open()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
|
||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
|
||||
log.Debugf("creating peer connection %s", pubKey)
|
||||
|
||||
wgConfig := peer.WgConfig{
|
||||
@@ -1229,11 +1260,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
||||
// randomize connection timeout
|
||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||
config := peer.ConnConfig{
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
AgentVersion: agentVersion,
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassConfig: peer.RosenpassConfig{
|
||||
PubKey: e.getRosenpassPubKey(),
|
||||
Addr: e.getRosenpassAddr(),
|
||||
@@ -1249,7 +1281,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
||||
},
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
PeerConnDispatcher: e.peerConnDispatcher,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1270,7 +1311,7 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
}
|
||||
@@ -1578,13 +1619,39 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes() bool {
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
signalHealthy := e.signal.IsHealthy()
|
||||
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
||||
|
||||
managementHealthy := e.mgmClient.IsHealthy()
|
||||
log.Debugf("management health check: healthy=%t", managementHealthy)
|
||||
|
||||
results := append(e.probeSTUNs(), e.probeTURNs()...)
|
||||
stuns := slices.Clone(e.STUNs)
|
||||
turns := slices.Clone(e.TURNs)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
stats, err := e.wgInterface.GetStats()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get wireguard stats: %v", err)
|
||||
e.syncMsgMux.Unlock()
|
||||
return false
|
||||
}
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
// wgStats could be zero value, in which case we just reset the stats
|
||||
wgStats, ok := stats[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
results := e.probeICE(stuns, turns)
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
relayHealthy := true
|
||||
@@ -1596,37 +1663,16 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
}
|
||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
wgStats, err := e.wgInterface.GetStats(key)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||
continue
|
||||
}
|
||||
// wgStats could be zero value, in which case we just reset the stats
|
||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
func (e *Engine) probeSTUNs() []relay.ProbeResult {
|
||||
e.syncMsgMux.Lock()
|
||||
stuns := slices.Clone(e.STUNs)
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
|
||||
}
|
||||
|
||||
func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
e.syncMsgMux.Lock()
|
||||
turns := slices.Clone(e.TURNs)
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
||||
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||
return append(
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)...,
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
@@ -1813,21 +1859,21 @@ func (e *Engine) Address() (netip.Addr, error) {
|
||||
return ip.Unmap(), nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||
if e.firewall == nil {
|
||||
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(rules) == 0 {
|
||||
if e.ingressGatewayMgr == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err := e.ingressGatewayMgr.Close()
|
||||
e.ingressGatewayMgr = nil
|
||||
e.statusRecorder.SetIngressGwMgr(nil)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if e.ingressGatewayMgr == nil {
|
||||
@@ -1878,7 +1924,35 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||
log.Errorf("failed to update forwarding rules: %v", err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return forwardingRules, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
|
||||
excludedPeers := make(map[string]bool)
|
||||
for _, r := range routes {
|
||||
if r.Peer == "" {
|
||||
continue
|
||||
}
|
||||
if !excludedPeers[r.Peer] {
|
||||
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
|
||||
excludedPeers[r.Peer] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
ip := r.TranslatedAddress
|
||||
for _, p := range peers {
|
||||
for _, allowedIP := range p.GetAllowedIps() {
|
||||
if allowedIP != ip.String() {
|
||||
continue
|
||||
}
|
||||
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
|
||||
excludedPeers[p.GetWgPubKey()] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return excludedPeers
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
|
||||
@@ -28,8 +28,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -38,6 +36,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -53,6 +52,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
@@ -93,7 +93,7 @@ type MockWGIface struct {
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetNetFunc func() *netstack.Net
|
||||
@@ -171,8 +171,8 @@ func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||
return m.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
func (m *MockWGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
return m.GetStatsFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
@@ -378,6 +378,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||
@@ -400,6 +403,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
engine.ctx = ctx
|
||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -770,6 +775,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
|
||||
engine.routeManager = mockRouteManager
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
exitErr := engine.Stop()
|
||||
@@ -966,6 +973,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
engine.dnsServer = mockDNSServer
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
exitErr := engine.Stop()
|
||||
@@ -1476,7 +1485,7 @@ func getConnectedPeers(e *Engine) int {
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.Status() == peer.StatusConnected {
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,6 @@ type wgIfaceBase interface {
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
9
client/internal/lazyconn/activity/listen_ip.go
Normal file
9
client/internal/lazyconn/activity/listen_ip.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package activity
|
||||
|
||||
import "net"
|
||||
|
||||
var (
|
||||
listenIP = net.IP{127, 0, 0, 1}
|
||||
)
|
||||
10
client/internal/lazyconn/activity/listen_ip_linux.go
Normal file
10
client/internal/lazyconn/activity/listen_ip_linux.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !android
|
||||
|
||||
package activity
|
||||
|
||||
import "net"
|
||||
|
||||
var (
|
||||
// use this ip to avoid eBPF proxy congestion
|
||||
listenIP = net.IP{127, 0, 1, 1}
|
||||
)
|
||||
106
client/internal/lazyconn/activity/listener.go
Normal file
106
client/internal/lazyconn/activity/listener.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||
type Listener struct {
|
||||
wgIface lazyconn.WGIface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
conn *net.UDPConn
|
||||
endpoint *net.UDPAddr
|
||||
done sync.Mutex
|
||||
|
||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||
}
|
||||
|
||||
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
d := &Listener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
}
|
||||
|
||||
conn, err := d.newConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
|
||||
}
|
||||
d.conn = conn
|
||||
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
if err := d.createEndpoint(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.done.Lock()
|
||||
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Listener) ReadPackets() {
|
||||
for {
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||
if err != nil {
|
||||
if d.isClosed.Load() {
|
||||
d.peerCfg.Log.Debugf("exit from activity listener")
|
||||
} else {
|
||||
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if n < 1 {
|
||||
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err := d.removeEndpoint(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
|
||||
d.done.Unlock()
|
||||
}
|
||||
|
||||
func (d *Listener) Close() {
|
||||
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
|
||||
d.isClosed.Store(true)
|
||||
|
||||
if err := d.conn.Close(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to close UDP listener: %s", err)
|
||||
}
|
||||
d.done.Lock()
|
||||
}
|
||||
|
||||
func (d *Listener) removeEndpoint() error {
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (d *Listener) createEndpoint() error {
|
||||
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
|
||||
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
|
||||
}
|
||||
|
||||
func (d *Listener) newConn() (*net.UDPConn, error) {
|
||||
addr := &net.UDPAddr{
|
||||
Port: 0,
|
||||
IP: listenIP,
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create activity listener on %s: %s", addr, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
41
client/internal/lazyconn/activity/listener_test.go
Normal file
41
client/internal/lazyconn/activity/listener_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestNewListener(t *testing.T) {
|
||||
peer := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
l, err := NewListener(MocWGIface{}, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
chanClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(chanClosed)
|
||||
l.ReadPackets()
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
l.Close()
|
||||
|
||||
select {
|
||||
case <-chanClosed:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
95
client/internal/lazyconn/activity/manager.go
Normal file
95
client/internal/lazyconn/activity/manager.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
OnActivityChan chan peerid.ConnID
|
||||
|
||||
wgIface lazyconn.WGIface
|
||||
|
||||
peers map[peerid.ConnID]*Listener
|
||||
done chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(wgIface lazyconn.WGIface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
wgIface: wgIface,
|
||||
peers: make(map[peerid.ConnID]*Listener),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.peers[peerCfg.PeerConnID]; ok {
|
||||
log.Warnf("activity listener already exists for: %s", peerCfg.PublicKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
listener, err := NewListener(m.wgIface, peerCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.peers[peerCfg.PeerConnID] = listener
|
||||
|
||||
go m.waitForTraffic(listener, peerCfg.PeerConnID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
listener, ok := m.peers[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Debugf("removing activity listener")
|
||||
delete(m.peers, peerConnID)
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
func (m *Manager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
close(m.done)
|
||||
for peerID, listener := range m.peers {
|
||||
delete(m.peers, peerID)
|
||||
listener.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
|
||||
listener.ReadPackets()
|
||||
|
||||
m.mu.Lock()
|
||||
if _, ok := m.peers[peerConnID]; !ok {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(m.peers, peerConnID)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.notify(peerConnID)
|
||||
}
|
||||
|
||||
func (m *Manager) notify(peerConnID peerid.ConnID) {
|
||||
select {
|
||||
case <-m.done:
|
||||
case m.OnActivityChan <- peerConnID:
|
||||
}
|
||||
}
|
||||
162
client/internal/lazyconn/activity/manager_test.go
Normal file
162
client/internal/lazyconn/activity/manager_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
PeerID string
|
||||
}
|
||||
|
||||
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
return peerid.ConnID(m)
|
||||
}
|
||||
|
||||
type MocWGIface struct {
|
||||
}
|
||||
|
||||
func (m MocWGIface) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
mocWgInterface := &MocWGIface{}
|
||||
|
||||
peer1 := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
mgr := NewManager(mocWgInterface)
|
||||
defer mgr.Close()
|
||||
peerCfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
if peerConnID != peerCfg1.PeerConnID {
|
||||
t.Fatalf("unexpected peerConnID: %v", peerConnID)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_RemovePeerActivity(t *testing.T) {
|
||||
mocWgInterface := &MocWGIface{}
|
||||
|
||||
peer1 := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
mgr := NewManager(mocWgInterface)
|
||||
defer mgr.Close()
|
||||
|
||||
peerCfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
|
||||
|
||||
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
||||
|
||||
if err := trigger(addr); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-mgr.OnActivityChan:
|
||||
t.Fatal("should not have active activity")
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
mocWgInterface := &MocWGIface{}
|
||||
|
||||
peer1 := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
mgr := NewManager(mocWgInterface)
|
||||
defer mgr.Close()
|
||||
|
||||
peerCfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
peer2 := &MocPeer{}
|
||||
peerCfg2 := lazyconn.PeerConfig{
|
||||
PublicKey: peer2.PeerID,
|
||||
PeerConnID: peer2.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey2"),
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := mgr.MonitorPeerActivity(peerCfg2); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-mgr.OnActivityChan:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timed out waiting for activity")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func trigger(addr string) error {
|
||||
// Create a connection to the destination UDP address and port
|
||||
conn, err := net.Dial("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Write the bytes to the UDP connection
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
32
client/internal/lazyconn/doc.go
Normal file
32
client/internal/lazyconn/doc.go
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
Package lazyconn provides mechanisms for managing lazy connections, which activate on demand to optimize resource usage and establish connections efficiently.
|
||||
|
||||
## Overview
|
||||
|
||||
The package includes a `Manager` component responsible for:
|
||||
- Managing lazy connections activated on-demand
|
||||
- Managing inactivity monitors for lazy connections (based on peer disconnection events)
|
||||
- Maintaining a list of excluded peers that should always have permanent connections
|
||||
- Handling remote peer connection initiatives based on peer signaling
|
||||
|
||||
## Thread-Safe Operations
|
||||
|
||||
The `Manager` ensures thread safety across multiple operations, categorized by caller:
|
||||
|
||||
- **Engine (single goroutine)**:
|
||||
- `AddPeer`: Adds a peer to the connection manager.
|
||||
- `RemovePeer`: Removes a peer from the connection manager.
|
||||
- `ActivatePeer`: Activates a lazy connection for a peer. This come from Signal client
|
||||
- `ExcludePeer`: Marks peers for a permanent connection. Like router peers and other peers that should always have a connection.
|
||||
|
||||
- **Connection Dispatcher (any peer routine)**:
|
||||
- `onPeerConnected`: Suspend the inactivity monitor for an active peer connection.
|
||||
- `onPeerDisconnected`: Starts the inactivity monitor for a disconnected peer.
|
||||
|
||||
- **Activity Manager**:
|
||||
- `onPeerActivity`: Run peer.Open(context).
|
||||
|
||||
- **Inactivity Monitor**:
|
||||
- `onPeerInactivityTimedOut`: Close peer connection and restart activity monitor.
|
||||
*/
|
||||
package lazyconn
|
||||
26
client/internal/lazyconn/env.go
Normal file
26
client/internal/lazyconn/env.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
|
||||
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
|
||||
)
|
||||
|
||||
func IsLazyConnEnabledByEnv() bool {
|
||||
val := os.Getenv(EnvEnableLazyConn)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
enabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
peer "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
|
||||
MinimumInactivityThreshold = 3 * time.Minute
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
id peer.ConnID
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
inactivityThreshold time.Duration
|
||||
}
|
||||
|
||||
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
|
||||
i := &Monitor{
|
||||
id: peerID,
|
||||
timer: time.NewTimer(0),
|
||||
inactivityThreshold: threshold,
|
||||
}
|
||||
i.timer.Stop()
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
defer i.timer.Stop()
|
||||
|
||||
ctx, i.cancel = context.WithCancel(ctx)
|
||||
defer func() {
|
||||
defer i.cancel()
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
select {
|
||||
case timeoutChan <- i.id:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Monitor) Stop() {
|
||||
if i.cancel == nil {
|
||||
return
|
||||
}
|
||||
i.cancel()
|
||||
}
|
||||
|
||||
func (i *Monitor) PauseTimer() {
|
||||
i.timer.Stop()
|
||||
}
|
||||
|
||||
func (i *Monitor) ResetTimer() {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
}
|
||||
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
}
|
||||
|
||||
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
return peerid.ConnID(m)
|
||||
}
|
||||
|
||||
func TestInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReuseInactivityMonitor(t *testing.T) {
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
for i := 2; i > 0; i-- {
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(testTimeoutCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
testTimeoutCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
im.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPauseInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
trashHold := time.Second * 3
|
||||
im := NewInactivityMonitor(p.ConnID(), trashHold)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(ctx, timeoutChan)
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second) // grant time to start the monitor
|
||||
im.PauseTimer()
|
||||
|
||||
// check to do not receive timeout
|
||||
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
|
||||
defer thresholdCancel()
|
||||
select {
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-thresholdCtx.Done():
|
||||
// test ok
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
}
|
||||
|
||||
// test reset timer
|
||||
im.ResetTimer()
|
||||
|
||||
select {
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
// expected timeout
|
||||
}
|
||||
}
|
||||
404
client/internal/lazyconn/manager/manager.go
Normal file
404
client/internal/lazyconn/manager/manager.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
)
|
||||
|
||||
const (
|
||||
watcherActivity watcherType = iota
|
||||
watcherInactivity
|
||||
)
|
||||
|
||||
type watcherType int
|
||||
|
||||
type managedPeer struct {
|
||||
peerCfg *lazyconn.PeerConfig
|
||||
expectedWatcher watcherType
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
InactivityThreshold *time.Duration
|
||||
}
|
||||
|
||||
// Manager manages lazy connections
|
||||
// It is responsible for:
|
||||
// - Managing lazy connections activated on-demand
|
||||
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
|
||||
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||
// - Handling connection establishment based on peer signaling
|
||||
type Manager struct {
|
||||
peerStore *peerstore.Store
|
||||
connStateDispatcher *dispatcher.ConnectionDispatcher
|
||||
inactivityThreshold time.Duration
|
||||
|
||||
connStateListener *dispatcher.ConnectionListener
|
||||
managedPeers map[string]*lazyconn.PeerConfig
|
||||
managedPeersByConnID map[peerid.ConnID]*managedPeer
|
||||
excludes map[string]lazyconn.PeerConfig
|
||||
managedPeersMu sync.Mutex
|
||||
|
||||
activityManager *activity.Manager
|
||||
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
|
||||
|
||||
cancel context.CancelFunc
|
||||
onInactive chan peerid.ConnID
|
||||
}
|
||||
|
||||
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
|
||||
log.Infof("setup lazy connection service")
|
||||
m := &Manager{
|
||||
peerStore: peerStore,
|
||||
connStateDispatcher: connStateDispatcher,
|
||||
inactivityThreshold: inactivity.DefaultInactivityThreshold,
|
||||
managedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
|
||||
excludes: make(map[string]lazyconn.PeerConfig),
|
||||
activityManager: activity.NewManager(wgIface),
|
||||
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
|
||||
onInactive: make(chan peerid.ConnID),
|
||||
}
|
||||
|
||||
if config.InactivityThreshold != nil {
|
||||
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
|
||||
m.inactivityThreshold = *config.InactivityThreshold
|
||||
} else {
|
||||
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
m.connStateListener = &dispatcher.ConnectionListener{
|
||||
OnConnected: m.onPeerConnected,
|
||||
OnDisconnected: m.onPeerDisconnected,
|
||||
}
|
||||
|
||||
connStateDispatcher.AddListener(m.connStateListener)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Start starts the manager and listens for peer activity and inactivity events
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
defer m.close()
|
||||
|
||||
ctx, m.cancel = context.WithCancel(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(ctx, peerConnID)
|
||||
case peerConnID := <-m.onInactive:
|
||||
m.onPeerInactivityTimedOut(peerConnID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExcludePeer marks peers for a permanent connection
|
||||
// It removes peers from the managed list if they are added to the exclude list
|
||||
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
|
||||
// this case, we suppose that the connection status is connected or connecting.
|
||||
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
|
||||
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
added := make([]string, 0)
|
||||
excludes := make(map[string]lazyconn.PeerConfig, len(peerConfigs))
|
||||
|
||||
for _, peerCfg := range peerConfigs {
|
||||
log.Infof("update excluded lazy connection list with peer: %s", peerCfg.PublicKey)
|
||||
excludes[peerCfg.PublicKey] = peerCfg
|
||||
}
|
||||
|
||||
// if a peer is newly added to the exclude list, remove from the managed peers list
|
||||
for pubKey, peerCfg := range excludes {
|
||||
if _, wasExcluded := m.excludes[pubKey]; wasExcluded {
|
||||
continue
|
||||
}
|
||||
|
||||
added = append(added, pubKey)
|
||||
peerCfg.Log.Infof("peer newly added to lazy connection exclude list")
|
||||
m.removePeer(pubKey)
|
||||
}
|
||||
|
||||
// if a peer has been removed from exclude list then it should be added to the managed peers
|
||||
for pubKey, peerCfg := range m.excludes {
|
||||
if _, stillExcluded := excludes[pubKey]; stillExcluded {
|
||||
continue
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
|
||||
|
||||
if err := m.addActivePeer(ctx, peerCfg); err != nil {
|
||||
log.Errorf("failed to add peer to lazy connection manager: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
m.excludes = excludes
|
||||
return added
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
peerCfg.Log.Debugf("adding peer to lazy connection manager")
|
||||
|
||||
_, exists := m.excludes[peerCfg.PublicKey]
|
||||
if exists {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||
peerCfg.Log.Warnf("peer already managed")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(peerCfg); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherActivity,
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// AddActivePeers adds a list of peers to the lazy connection manager
|
||||
// suppose these peers was in connected or in connecting states
|
||||
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
for _, cfg := range peerCfg {
|
||||
if _, ok := m.managedPeers[cfg.PublicKey]; ok {
|
||||
cfg.Log.Errorf("peer already managed")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.addActivePeer(ctx, cfg); err != nil {
|
||||
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(peerID string) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
m.removePeer(peerID)
|
||||
}
|
||||
|
||||
// ActivatePeer activates a peer connection when a signal message is received
|
||||
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
cfg, ok := m.managedPeers[peerID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// signal messages coming continuously after success activation, with this avoid the multiple activation
|
||||
if mp.expectedWatcher == watcherInactivity {
|
||||
return false
|
||||
}
|
||||
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
|
||||
im, ok := m.inactivityMonitors[cfg.PeerConnID]
|
||||
if !ok {
|
||||
cfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
return false
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("starting inactivity monitor")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
|
||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||
peerCfg.Log.Warnf("peer already managed")
|
||||
return nil
|
||||
}
|
||||
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherInactivity,
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) removePeer(peerID string) {
|
||||
cfg, ok := m.managedPeers[peerID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Log.Infof("removing lazy peer")
|
||||
|
||||
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
|
||||
im.Stop()
|
||||
delete(m.inactivityMonitors, cfg.PeerConnID)
|
||||
cfg.Log.Debugf("inactivity monitor stopped")
|
||||
}
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
delete(m.managedPeers, peerID)
|
||||
delete(m.managedPeersByConnID, cfg.PeerConnID)
|
||||
}
|
||||
|
||||
func (m *Manager) close() {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
m.cancel()
|
||||
|
||||
m.connStateDispatcher.RemoveListener(m.connStateListener)
|
||||
m.activityManager.Close()
|
||||
for _, iw := range m.inactivityMonitors {
|
||||
iw.Stop()
|
||||
}
|
||||
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
|
||||
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
|
||||
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
|
||||
log.Infof("lazy connection manager closed")
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by conn id: %v", peerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherActivity {
|
||||
mp.peerCfg.Log.Warnf("ignore activity event")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("detected peer activity")
|
||||
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
|
||||
mp.peerCfg.Log.Infof("starting inactivity monitor")
|
||||
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
|
||||
|
||||
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by id: %v", peerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("connection timed out")
|
||||
|
||||
// this is blocking operation, potentially can be optimized
|
||||
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
// just in case free up
|
||||
m.inactivityMonitors[peerConnID].PauseTimer()
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
|
||||
iw.PauseTimer()
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
|
||||
iw.ResetTimer()
|
||||
}
|
||||
16
client/internal/lazyconn/peercfg.go
Normal file
16
client/internal/lazyconn/peercfg.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type PeerConfig struct {
|
||||
PublicKey string
|
||||
AllowedIPs []netip.Prefix
|
||||
PeerConnID id.ConnID
|
||||
Log *log.Entry
|
||||
}
|
||||
41
client/internal/lazyconn/support.go
Normal file
41
client/internal/lazyconn/support.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
var (
|
||||
minVersion = version.Must(version.NewVersion("0.45.0"))
|
||||
)
|
||||
|
||||
func IsSupported(agentVersion string) bool {
|
||||
if agentVersion == "development" {
|
||||
return true
|
||||
}
|
||||
|
||||
// filter out versions like this: a6c5960, a7d5c522, d47be154
|
||||
if !strings.Contains(agentVersion, ".") {
|
||||
return false
|
||||
}
|
||||
|
||||
normalizedVersion := normalizeVersion(agentVersion)
|
||||
inputVer, err := version.NewVersion(normalizedVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return inputVer.GreaterThanOrEqual(minVersion)
|
||||
}
|
||||
|
||||
func normalizeVersion(version string) string {
|
||||
// Remove prefixes like 'v' or 'a'
|
||||
if len(version) > 0 && (version[0] == 'v' || version[0] == 'a') {
|
||||
version = version[1:]
|
||||
}
|
||||
|
||||
// Remove any suffixes like '-dirty', '-dev', '-SNAPSHOT', etc.
|
||||
parts := strings.Split(version, "-")
|
||||
return parts[0]
|
||||
}
|
||||
31
client/internal/lazyconn/support_test.go
Normal file
31
client/internal/lazyconn/support_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package lazyconn
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
version string
|
||||
want bool
|
||||
}{
|
||||
{"development", true},
|
||||
{"0.45.0", true},
|
||||
{"v0.45.0", true},
|
||||
{"0.45.1", true},
|
||||
{"0.45.1-SNAPSHOT-559e6731", true},
|
||||
{"v0.45.1-dev", true},
|
||||
{"a7d5c522", false},
|
||||
{"0.9.6", false},
|
||||
{"0.9.6-SNAPSHOT", false},
|
||||
{"0.9.6-SNAPSHOT-2033650", false},
|
||||
{"meta_wt_version", false},
|
||||
{"v0.31.1-dev", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.version, func(t *testing.T) {
|
||||
if got := IsSupported(tt.version); got != tt.want {
|
||||
t.Errorf("IsSupported() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
14
client/internal/lazyconn/wgiface.go
Normal file
14
client/internal/lazyconn/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
}
|
||||
@@ -123,8 +123,14 @@ func (m *Manager) disableFlow() error {
|
||||
|
||||
m.logger.Close()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
if m.receiverClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := m.receiverClient.Close()
|
||||
m.receiverClient = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("close: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -17,8 +17,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -26,32 +30,20 @@ import (
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
type ConnPriority int
|
||||
|
||||
func (cp ConnPriority) String() string {
|
||||
switch cp {
|
||||
case connPriorityNone:
|
||||
return "None"
|
||||
case connPriorityRelay:
|
||||
return "PriorityRelay"
|
||||
case connPriorityICETurn:
|
||||
return "PriorityICETurn"
|
||||
case connPriorityICEP2P:
|
||||
return "PriorityICEP2P"
|
||||
default:
|
||||
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
|
||||
connPriorityNone ConnPriority = 0
|
||||
connPriorityRelay ConnPriority = 1
|
||||
connPriorityICETurn ConnPriority = 2
|
||||
connPriorityICEP2P ConnPriority = 3
|
||||
)
|
||||
|
||||
type ServiceDependencies struct {
|
||||
StatusRecorder *Status
|
||||
Signaler *Signaler
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
Semaphore *semaphoregroup.SemaphoreGroup
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
}
|
||||
|
||||
type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
@@ -76,6 +68,8 @@ type ConnConfig struct {
|
||||
// LocalKey is a public key of a local peer
|
||||
LocalKey string
|
||||
|
||||
AgentVersion string
|
||||
|
||||
Timeout time.Duration
|
||||
|
||||
WgConfig WgConfig
|
||||
@@ -89,22 +83,23 @@ type ConnConfig struct {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
log *log.Entry
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
handshaker *Handshaker
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
|
||||
statusRelay *AtomicConnStatus
|
||||
statusICE *AtomicConnStatus
|
||||
currentConnPriority ConnPriority
|
||||
statusRelay *worker.AtomicWorkerStatus
|
||||
statusICE *worker.AtomicWorkerStatus
|
||||
currentConnPriority conntype.ConnPriority
|
||||
opened bool // this flag is used to prevent close in case of not opened connection
|
||||
|
||||
workerICE *WorkerICE
|
||||
@@ -120,9 +115,12 @@ type Conn struct {
|
||||
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
handshaker *Handshaker
|
||||
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
wg sync.WaitGroup
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
@@ -130,91 +128,101 @@ type Conn struct {
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
|
||||
func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
if len(config.WgConfig.AllowedIps) == 0 {
|
||||
return nil, fmt.Errorf("allowed IPs is empty")
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(engineCtx)
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
log: connLog,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
semaphore: semaphore,
|
||||
dumpState: newStateDump(config.Key, connLog, statusRecorder),
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
peerConnDispatcher: services.PeerConnDispatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
}
|
||||
|
||||
ctrl := isController(config)
|
||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
|
||||
|
||||
go conn.handshaker.Listen()
|
||||
|
||||
go conn.dumpState.Start(ctx)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Open opens connection to the remote peer
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open() {
|
||||
conn.semaphore.Add(conn.ctx)
|
||||
conn.log.Debugf("open connection to peer")
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.semaphore.Add(engineCtx)
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
conn.opened = true
|
||||
|
||||
if conn.opened {
|
||||
conn.semaphore.Done(engineCtx)
|
||||
return nil
|
||||
}
|
||||
|
||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
defer conn.wg.Done()
|
||||
conn.handshaker.Listen(conn.ctx)
|
||||
}()
|
||||
go conn.dumpState.Start(conn.ctx)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: StatusDisconnected,
|
||||
ConnStatus: StatusConnecting,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("error while updating the state err: %v", err)
|
||||
if err := conn.statusRecorder.UpdatePeerState(peerState); err != nil {
|
||||
conn.Log.Warnf("error while updating the state err: %v", err)
|
||||
}
|
||||
|
||||
go conn.startHandshakeAndReconnect(conn.ctx)
|
||||
}
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
defer conn.wg.Done()
|
||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||
conn.semaphore.Done(conn.ctx)
|
||||
|
||||
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||
defer conn.semaphore.Done(conn.ctx)
|
||||
conn.waitInitialRandomSleepTime(ctx)
|
||||
conn.dumpState.SendOffer()
|
||||
if err := conn.handshaker.sendOffer(); err != nil {
|
||||
conn.Log.Errorf("failed to send initial offer: %v", err)
|
||||
}
|
||||
|
||||
conn.dumpState.SendOffer()
|
||||
err := conn.handshaker.sendOffer()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to send initial offer: %v", err)
|
||||
}
|
||||
|
||||
go conn.guard.Start(ctx)
|
||||
go conn.listenGuardEvent(ctx)
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
conn.wg.Done()
|
||||
}()
|
||||
}()
|
||||
conn.opened = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
@@ -223,14 +231,14 @@ func (conn *Conn) Close() {
|
||||
defer conn.wgWatcherWg.Wait()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
conn.log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
if !conn.opened {
|
||||
conn.log.Debugf("ignore close connection to peer")
|
||||
conn.Log.Debugf("ignore close connection to peer")
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
@@ -238,7 +246,7 @@ func (conn *Conn) Close() {
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to close wg proxy for relay: %v", err)
|
||||
conn.Log.Errorf("failed to close wg proxy for relay: %v", err)
|
||||
}
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
@@ -246,13 +254,13 @@ func (conn *Conn) Close() {
|
||||
if conn.wgProxyICE != nil {
|
||||
err := conn.wgProxyICE.CloseConn()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to close wg proxy for ice: %v", err)
|
||||
conn.Log.Errorf("failed to close wg proxy for ice: %v", err)
|
||||
}
|
||||
conn.wgProxyICE = nil
|
||||
}
|
||||
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
conn.freeUpConnID()
|
||||
@@ -262,14 +270,16 @@ func (conn *Conn) Close() {
|
||||
}
|
||||
|
||||
conn.setStatusToDisconnected()
|
||||
conn.log.Infof("peer connection has been closed")
|
||||
conn.opened = false
|
||||
conn.wg.Wait()
|
||||
conn.Log.Infof("peer connection closed")
|
||||
}
|
||||
|
||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
conn.dumpState.RemoteAnswer()
|
||||
conn.log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
|
||||
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
|
||||
return conn.handshaker.OnRemoteAnswer(answer)
|
||||
}
|
||||
|
||||
@@ -298,7 +308,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
|
||||
|
||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
||||
conn.dumpState.RemoteOffer()
|
||||
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||
return conn.handshaker.OnRemoteOffer(offer)
|
||||
}
|
||||
|
||||
@@ -307,19 +317,24 @@ func (conn *Conn) WgConfig() WgConfig {
|
||||
return conn.config.WgConfig
|
||||
}
|
||||
|
||||
// Status returns current status of the Conn
|
||||
func (conn *Conn) Status() ConnStatus {
|
||||
// IsConnected unit tests only
|
||||
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
|
||||
func (conn *Conn) IsConnected() bool {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
return conn.evalStatus()
|
||||
return conn.currentConnPriority != conntype.None
|
||||
}
|
||||
|
||||
func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
||||
func (conn *Conn) ConnID() id.ConnID {
|
||||
return id.ConnID(conn)
|
||||
}
|
||||
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
||||
func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConnInfo ICEConnInfo) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -327,21 +342,21 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
return
|
||||
}
|
||||
|
||||
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
||||
conn.log.Errorf("remote ICE connection is nil")
|
||||
if remoteConnNil(conn.Log, iceConnInfo.RemoteConn) {
|
||||
conn.Log.Errorf("remote ICE connection is nil")
|
||||
return
|
||||
}
|
||||
|
||||
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
|
||||
// todo consider to remove this check
|
||||
if conn.currentConnPriority > priority {
|
||||
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Infof("set ICE to active connection")
|
||||
conn.Log.Infof("set ICE to active connection")
|
||||
conn.dumpState.P2PConnected()
|
||||
|
||||
var (
|
||||
@@ -353,7 +368,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
conn.dumpState.NewLocalProxy()
|
||||
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
ep = wgProxy.EndpointAddr()
|
||||
@@ -369,7 +384,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
@@ -388,10 +403,16 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
oldState := conn.currentConnPriority
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
|
||||
if oldState == conntype.None {
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
@@ -402,22 +423,22 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Tracef("ICE connection state changed to disconnected")
|
||||
conn.Log.Tracef("ICE connection state changed to disconnected")
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// switch back to relay connection
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||
conn.dumpState.SwitchToRelay()
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
@@ -425,17 +446,18 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
||||
conn.currentConnPriority = connPriorityNone
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != StatusDisconnected
|
||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||
if changed {
|
||||
conn.guard.SetICEConnDisconnected()
|
||||
}
|
||||
conn.statusICE.Set(StatusDisconnected)
|
||||
conn.statusICE.SetDisconnected()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -446,7 +468,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -456,41 +478,41 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
if err := rci.relayedConn.Close(); err != nil {
|
||||
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
conn.Log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.dumpState.RelayConnected()
|
||||
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
|
||||
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
conn.dumpState.NewLocalProxy()
|
||||
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
|
||||
if conn.isICEActive() {
|
||||
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -502,12 +524,13 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.Log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
|
||||
func (conn *Conn) onRelayDisconnected() {
|
||||
@@ -518,14 +541,15 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Infof("relay connection is disconnected")
|
||||
conn.Log.Debugf("relay connection is disconnected")
|
||||
|
||||
if conn.currentConnPriority == connPriorityRelay {
|
||||
conn.log.Infof("clean up WireGuard config")
|
||||
if conn.currentConnPriority == conntype.Relay {
|
||||
conn.Log.Debugf("clean up WireGuard config")
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
conn.currentConnPriority = connPriorityNone
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -533,11 +557,11 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
|
||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||
changed := conn.statusRelay.Get() != worker.StatusDisconnected
|
||||
if changed {
|
||||
conn.guard.SetRelayedConnDisconnected()
|
||||
}
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
conn.statusRelay.SetDisconnected()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -546,22 +570,15 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||
conn.Log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) listenGuardEvent(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-conn.guard.Reconnect:
|
||||
conn.log.Infof("send offer to peer")
|
||||
conn.dumpState.SendOffer()
|
||||
if err := conn.handshaker.SendOffer(); err != nil {
|
||||
conn.log.Errorf("failed to send offer: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
func (conn *Conn) onGuardEvent() {
|
||||
conn.Log.Debugf("send offer to peer")
|
||||
conn.dumpState.SendOffer()
|
||||
if err := conn.handshaker.SendOffer(); err != nil {
|
||||
conn.Log.Errorf("failed to send offer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -588,7 +605,7 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("unable to save peer's Relay state, got error: %v", err)
|
||||
conn.Log.Warnf("unable to save peer's Relay state, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -607,17 +624,18 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerICEState(peerState)
|
||||
if err != nil {
|
||||
conn.log.Warnf("unable to save peer's ICE state, got error: %v", err)
|
||||
conn.Log.Warnf("unable to save peer's ICE state, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setStatusToDisconnected() {
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
conn.statusICE.Set(StatusDisconnected)
|
||||
conn.statusRelay.SetDisconnected()
|
||||
conn.statusICE.SetDisconnected()
|
||||
conn.currentConnPriority = conntype.None
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: StatusDisconnected,
|
||||
ConnStatus: StatusIdle,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
@@ -625,10 +643,10 @@ func (conn *Conn) setStatusToDisconnected() {
|
||||
if err != nil {
|
||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||
// todo rethink status updates
|
||||
conn.log.Debugf("error while updating peer's state, err: %v", err)
|
||||
conn.Log.Debugf("error while updating peer's state, err: %v", err)
|
||||
}
|
||||
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
|
||||
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
|
||||
conn.Log.Debugf("failed to reset wireguard stats for peer: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -656,32 +674,24 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (conn *Conn) isRelayed() bool {
|
||||
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay, conntype.ICETurn:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.currentConnPriority == connPriorityICEP2P {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) evalStatus() ConnStatus {
|
||||
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
|
||||
if conn.statusRelay.Get() == worker.StatusConnected || conn.statusICE.Get() == worker.StatusConnected {
|
||||
return StatusConnected
|
||||
}
|
||||
|
||||
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
return StatusDisconnected
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
defer func() {
|
||||
if !connected {
|
||||
@@ -689,12 +699,12 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
if conn.statusICE.Get() == StatusDisconnected {
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() != StatusConnected {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -716,7 +726,7 @@ func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDRelay != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDRelay); err != nil {
|
||||
conn.log.Errorf("After remove peer hook failed: %v", err)
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDRelay = ""
|
||||
@@ -725,7 +735,7 @@ func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDICE != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDICE); err != nil {
|
||||
conn.log.Errorf("After remove peer hook failed: %v", err)
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDICE = ""
|
||||
@@ -733,7 +743,7 @@ func (conn *Conn) freeUpConnID() {
|
||||
}
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
|
||||
Port: conn.config.WgConfig.WgListenPort,
|
||||
@@ -741,18 +751,18 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
|
||||
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||
}
|
||||
|
||||
func (conn *Conn) isICEActive() bool {
|
||||
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
|
||||
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
@@ -760,10 +770,10 @@ func (conn *Conn) removeWgPeer() error {
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||
conn.Log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||
}
|
||||
}
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -773,16 +783,16 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
|
||||
func (conn *Conn) logTraceConnState() {
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||
conn.Log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||
} else {
|
||||
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
|
||||
conn.Log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = proxy
|
||||
@@ -793,6 +803,10 @@ func (conn *Conn) AllowedIP() netip.Addr {
|
||||
return conn.config.WgConfig.AllowedIps[0].Addr()
|
||||
}
|
||||
|
||||
func (conn *Conn) AgentVersionString() string {
|
||||
return conn.config.AgentVersion
|
||||
}
|
||||
|
||||
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
if conn.config.RosenpassConfig.PubKey == nil {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
@@ -804,7 +818,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
|
||||
@@ -1,58 +1,29 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// StatusConnected indicate the peer is in connected state
|
||||
StatusConnected ConnStatus = iota
|
||||
// StatusIdle indicate the peer is in disconnected state
|
||||
StatusIdle ConnStatus = iota
|
||||
// StatusConnecting indicate the peer is in connecting state
|
||||
StatusConnecting
|
||||
// StatusDisconnected indicate the peer is in disconnected state
|
||||
StatusDisconnected
|
||||
// StatusConnected indicate the peer is in connected state
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
|
||||
type AtomicConnStatus struct {
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
|
||||
func NewAtomicConnStatus() *AtomicConnStatus {
|
||||
acs := &AtomicConnStatus{}
|
||||
acs.Set(StatusDisconnected)
|
||||
return acs
|
||||
}
|
||||
|
||||
// Get returns the current connection status
|
||||
func (acs *AtomicConnStatus) Get() ConnStatus {
|
||||
return ConnStatus(acs.status.Load())
|
||||
}
|
||||
|
||||
// Set updates the connection status
|
||||
func (acs *AtomicConnStatus) Set(status ConnStatus) {
|
||||
acs.status.Store(int32(status))
|
||||
}
|
||||
|
||||
// String returns the string representation of the current status
|
||||
func (acs *AtomicConnStatus) String() string {
|
||||
return acs.Get().String()
|
||||
}
|
||||
|
||||
func (s ConnStatus) String() string {
|
||||
switch s {
|
||||
case StatusConnecting:
|
||||
return "Connecting"
|
||||
case StatusConnected:
|
||||
return "Connected"
|
||||
case StatusDisconnected:
|
||||
return "Disconnected"
|
||||
case StatusIdle:
|
||||
return "Idle"
|
||||
default:
|
||||
log.Errorf("unknown status: %d", s)
|
||||
return "INVALID_PEER_CONNECTION_STATUS"
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestConnStatus_String(t *testing.T) {
|
||||
want string
|
||||
}{
|
||||
{"StatusConnected", StatusConnected, "Connected"},
|
||||
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
||||
{"StatusIdle", StatusIdle, "Idle"},
|
||||
{"StatusConnecting", StatusConnecting, "Connecting"},
|
||||
}
|
||||
|
||||
@@ -24,5 +24,4 @@ func TestConnStatus_String(t *testing.T) {
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
|
||||
var connConf = ConnConfig{
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
@@ -48,7 +50,13 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
|
||||
sd := ServiceDependencies{
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -60,7 +68,13 @@ func TestConn_GetKey(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -94,7 +108,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -125,43 +145,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables := []struct {
|
||||
name string
|
||||
statusIce ConnStatus
|
||||
statusRelay ConnStatus
|
||||
want ConnStatus
|
||||
}{
|
||||
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
|
||||
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
|
||||
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
|
||||
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
|
||||
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
|
||||
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
|
||||
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
t.Run(table.name, func(t *testing.T) {
|
||||
si := NewAtomicConnStatus()
|
||||
si.Set(table.statusIce)
|
||||
conn.statusICE = si
|
||||
|
||||
sr := NewAtomicConnStatus()
|
||||
sr.Set(table.statusRelay)
|
||||
conn.statusRelay = sr
|
||||
|
||||
got := conn.Status()
|
||||
assert.Equal(t, got, table.want, "they should be equal")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_presharedKey(t *testing.T) {
|
||||
conn1 := Conn{
|
||||
|
||||
29
client/internal/peer/conntype/priority.go
Normal file
29
client/internal/peer/conntype/priority.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package conntype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
None ConnPriority = 0
|
||||
Relay ConnPriority = 1
|
||||
ICETurn ConnPriority = 2
|
||||
ICEP2P ConnPriority = 3
|
||||
)
|
||||
|
||||
type ConnPriority int
|
||||
|
||||
func (cp ConnPriority) String() string {
|
||||
switch cp {
|
||||
case None:
|
||||
return "None"
|
||||
case Relay:
|
||||
return "PriorityRelay"
|
||||
case ICETurn:
|
||||
return "PriorityICETurn"
|
||||
case ICEP2P:
|
||||
return "PriorityICEP2P"
|
||||
default:
|
||||
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||
}
|
||||
}
|
||||
52
client/internal/peer/dispatcher/dispatcher.go
Normal file
52
client/internal/peer/dispatcher/dispatcher.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package dispatcher
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type ConnectionListener struct {
|
||||
OnConnected func(peerID id.ConnID)
|
||||
OnDisconnected func(peerID id.ConnID)
|
||||
}
|
||||
|
||||
type ConnectionDispatcher struct {
|
||||
listeners map[*ConnectionListener]struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectionDispatcher() *ConnectionDispatcher {
|
||||
return &ConnectionDispatcher{
|
||||
listeners: make(map[*ConnectionListener]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) AddListener(listener *ConnectionListener) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.listeners[listener] = struct{}{}
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) RemoveListener(listener *ConnectionListener) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
delete(e.listeners, listener)
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) NotifyConnected(peerConnID id.ConnID) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
for listener := range e.listeners {
|
||||
listener.OnConnected(peerConnID)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConnectionDispatcher) NotifyDisconnected(peerConnID id.ConnID) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
for listener := range e.listeners {
|
||||
listener.OnDisconnected(peerConnID)
|
||||
}
|
||||
}
|
||||
@@ -8,10 +8,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
reconnectMaxElapsedTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
@@ -25,7 +21,6 @@ type isConnectedFunc func() bool
|
||||
type Guard struct {
|
||||
Reconnect chan struct{}
|
||||
log *log.Entry
|
||||
isController bool
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
@@ -33,11 +28,10 @@ type Guard struct {
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
Reconnect: make(chan struct{}, 1),
|
||||
log: log,
|
||||
isController: isController,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
timeout: timeout,
|
||||
srWatcher: srWatcher,
|
||||
@@ -46,12 +40,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) Start(ctx context.Context) {
|
||||
if g.isController {
|
||||
g.reconnectLoopWithRetry(ctx)
|
||||
} else {
|
||||
g.listenForDisconnectEvents(ctx)
|
||||
}
|
||||
func (g *Guard) Start(ctx context.Context, eventCallback func()) {
|
||||
g.reconnectLoopWithRetry(ctx, eventCallback)
|
||||
}
|
||||
|
||||
func (g *Guard) SetRelayedConnDisconnected() {
|
||||
@@ -68,9 +58,9 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check (max 30 min) the connection status.
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
waitForInitialConnectionTry(ctx)
|
||||
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
@@ -93,7 +83,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
g.triggerOfferSending()
|
||||
callback()
|
||||
}
|
||||
|
||||
case <-g.relayedConnDisconnected:
|
||||
@@ -121,39 +111,12 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
|
||||
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
|
||||
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
|
||||
// mean that to switch to it. We always force to use the higher priority connection.
|
||||
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
|
||||
g.log.Infof("start listen for reconnect events...")
|
||||
for {
|
||||
select {
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, triggering reconnect")
|
||||
g.triggerOfferSending()
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE state changed, try to send new offer")
|
||||
g.triggerOfferSending()
|
||||
case <-srReconnectedChan:
|
||||
g.triggerOfferSending()
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
Multiplier: 2,
|
||||
MaxInterval: g.timeout,
|
||||
MaxElapsedTime: reconnectMaxElapsedTime,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
@@ -164,13 +127,6 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
return ticker
|
||||
}
|
||||
|
||||
func (g *Guard) triggerOfferSending() {
|
||||
select {
|
||||
case g.Reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Give chance to the peer to establish the initial connection.
|
||||
// With it, we can decrease to send necessary offer
|
||||
func waitForInitialConnectionTry(ctx context.Context) {
|
||||
|
||||
@@ -43,7 +43,6 @@ type OfferAnswer struct {
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
@@ -57,9 +56,8 @@ type Handshaker struct {
|
||||
remoteAnswerCh chan OfferAnswer
|
||||
}
|
||||
|
||||
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||
return &Handshaker{
|
||||
ctx: ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -74,10 +72,10 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen() {
|
||||
func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
h.log.Info("wait for remote offer confirmation")
|
||||
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
|
||||
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx)
|
||||
if err != nil {
|
||||
var connectionClosedError *ConnectionClosedError
|
||||
if errors.As(err, &connectionClosedError) {
|
||||
@@ -127,7 +125,7 @@ func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
||||
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
// received confirmation from the remote peer -> ready to proceed
|
||||
@@ -137,7 +135,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
||||
return &remoteOfferAnswer, nil
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
return &remoteOfferAnswer, nil
|
||||
case <-h.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
// closed externally
|
||||
return nil, NewConnectionClosedError(h.config.Key)
|
||||
}
|
||||
|
||||
5
client/internal/peer/id/connid.go
Normal file
5
client/internal/peer/id/connid.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package id
|
||||
|
||||
import "unsafe"
|
||||
|
||||
type ConnID unsafe.Pointer
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
type WGIface interface {
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetProxy() wgproxy.Proxy
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ type notifier struct {
|
||||
currentClientState bool
|
||||
lastNotification int
|
||||
lastNumberOfPeers int
|
||||
lastFqdnAddress string
|
||||
lastIPAddress string
|
||||
}
|
||||
|
||||
func newNotifier() *notifier {
|
||||
@@ -25,15 +27,22 @@ func newNotifier() *notifier {
|
||||
}
|
||||
|
||||
func (n *notifier) setListener(listener Listener) {
|
||||
n.serverStateLock.Lock()
|
||||
lastNotification := n.lastNotification
|
||||
numOfPeers := n.lastNumberOfPeers
|
||||
fqdnAddress := n.lastFqdnAddress
|
||||
address := n.lastIPAddress
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
n.serverStateLock.Lock()
|
||||
n.notifyListener(listener, n.lastNotification)
|
||||
listener.OnPeersListChanged(n.lastNumberOfPeers)
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.listener = listener
|
||||
|
||||
listener.OnAddressChanged(fqdnAddress, address)
|
||||
notifyListener(listener, lastNotification)
|
||||
// run on go routine to avoid on Java layer to call go functions on same thread
|
||||
go listener.OnPeersListChanged(numOfPeers)
|
||||
}
|
||||
|
||||
func (n *notifier) removeListener() {
|
||||
@@ -44,41 +53,44 @@ func (n *notifier) removeListener() {
|
||||
|
||||
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
|
||||
calculatedState := n.calculateState(mgmState, signalState)
|
||||
|
||||
if !n.isServerStateChanged(calculatedState) {
|
||||
n.serverStateLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
n.lastNotification = calculatedState
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.notify(n.lastNotification)
|
||||
n.notify(calculatedState)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStart() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = true
|
||||
n.lastNotification = stateConnecting
|
||||
n.notify(n.lastNotification)
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.notify(stateConnecting)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStop() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = false
|
||||
n.lastNotification = stateDisconnected
|
||||
n.notify(n.lastNotification)
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.notify(stateDisconnected)
|
||||
}
|
||||
|
||||
func (n *notifier) clientTearDown() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = false
|
||||
n.lastNotification = stateDisconnecting
|
||||
n.notify(n.lastNotification)
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.notify(stateDisconnecting)
|
||||
}
|
||||
|
||||
func (n *notifier) isServerStateChanged(newState int) bool {
|
||||
@@ -87,26 +99,14 @@ func (n *notifier) isServerStateChanged(newState int) bool {
|
||||
|
||||
func (n *notifier) notify(state int) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
if n.listener == nil {
|
||||
listener := n.listener
|
||||
n.listenersLock.Unlock()
|
||||
|
||||
if listener == nil {
|
||||
return
|
||||
}
|
||||
n.notifyListener(n.listener, state)
|
||||
}
|
||||
|
||||
func (n *notifier) notifyListener(l Listener, state int) {
|
||||
go func() {
|
||||
switch state {
|
||||
case stateDisconnected:
|
||||
l.OnDisconnected()
|
||||
case stateConnected:
|
||||
l.OnConnected()
|
||||
case stateConnecting:
|
||||
l.OnConnecting()
|
||||
case stateDisconnecting:
|
||||
l.OnDisconnecting()
|
||||
}
|
||||
}()
|
||||
notifyListener(listener, state)
|
||||
}
|
||||
|
||||
func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
||||
@@ -126,20 +126,48 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
||||
}
|
||||
|
||||
func (n *notifier) peerListChanged(numOfPeers int) {
|
||||
n.serverStateLock.Lock()
|
||||
n.lastNumberOfPeers = numOfPeers
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
if n.listener == nil {
|
||||
listener := n.listener
|
||||
n.listenersLock.Unlock()
|
||||
|
||||
if listener == nil {
|
||||
return
|
||||
}
|
||||
n.listener.OnPeersListChanged(numOfPeers)
|
||||
|
||||
// run on go routine to avoid on Java layer to call go functions on same thread
|
||||
go listener.OnPeersListChanged(numOfPeers)
|
||||
}
|
||||
|
||||
func (n *notifier) localAddressChanged(fqdn, address string) {
|
||||
n.serverStateLock.Lock()
|
||||
n.lastFqdnAddress = fqdn
|
||||
n.lastIPAddress = address
|
||||
n.serverStateLock.Unlock()
|
||||
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
if n.listener == nil {
|
||||
listener := n.listener
|
||||
n.listenersLock.Unlock()
|
||||
|
||||
if listener == nil {
|
||||
return
|
||||
}
|
||||
n.listener.OnAddressChanged(fqdn, address)
|
||||
|
||||
listener.OnAddressChanged(fqdn, address)
|
||||
}
|
||||
|
||||
func notifyListener(l Listener, state int) {
|
||||
switch state {
|
||||
case stateDisconnected:
|
||||
l.OnDisconnected()
|
||||
case stateConnected:
|
||||
l.OnConnected()
|
||||
case stateConnecting:
|
||||
l.OnConnecting()
|
||||
case stateDisconnecting:
|
||||
l.OnDisconnecting()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,14 +135,15 @@ type NSGroupState struct {
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
type FullStatus struct {
|
||||
Peers []State
|
||||
ManagementState ManagementState
|
||||
SignalState SignalState
|
||||
LocalPeerState LocalPeerState
|
||||
RosenpassState RosenpassState
|
||||
Relays []relay.ProbeResult
|
||||
NSGroupStates []NSGroupState
|
||||
NumOfForwardingRules int
|
||||
Peers []State
|
||||
ManagementState ManagementState
|
||||
SignalState SignalState
|
||||
LocalPeerState LocalPeerState
|
||||
RosenpassState RosenpassState
|
||||
Relays []relay.ProbeResult
|
||||
NSGroupStates []NSGroupState
|
||||
NumOfForwardingRules int
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
@@ -164,6 +165,7 @@ type Status struct {
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||
lazyConnectionEnabled bool
|
||||
|
||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||
@@ -219,7 +221,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||
}
|
||||
|
||||
// AddPeer adds peer to Daemon status map
|
||||
func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
||||
func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -229,7 +231,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
||||
}
|
||||
d.peers[peerPubKey] = State{
|
||||
PubKey: peerPubKey,
|
||||
ConnStatus: StatusDisconnected,
|
||||
IP: ip,
|
||||
ConnStatus: StatusIdle,
|
||||
FQDN: fqdn,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
@@ -511,9 +514,9 @@ func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
|
||||
switch {
|
||||
case receivedConnStatus == StatusConnecting:
|
||||
return true
|
||||
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
|
||||
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting:
|
||||
return true
|
||||
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
|
||||
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
|
||||
return curr.IP != ""
|
||||
default:
|
||||
return false
|
||||
@@ -689,6 +692,12 @@ func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
|
||||
d.rosenpassEnabled = rosenpassEnabled
|
||||
}
|
||||
|
||||
func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.lazyConnectionEnabled = enabled
|
||||
}
|
||||
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
@@ -761,6 +770,12 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) GetLazyConnection() bool {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return d.lazyConnectionEnabled
|
||||
}
|
||||
|
||||
func (d *Status) GetManagementState() ManagementState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
@@ -872,12 +887,13 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
fullStatus := FullStatus{
|
||||
ManagementState: d.GetManagementState(),
|
||||
SignalState: d.GetSignalState(),
|
||||
Relays: d.GetRelayStates(),
|
||||
RosenpassState: d.GetRosenpassState(),
|
||||
NSGroupStates: d.GetDNSStates(),
|
||||
NumOfForwardingRules: len(d.ForwardingRules()),
|
||||
ManagementState: d.GetManagementState(),
|
||||
SignalState: d.GetSignalState(),
|
||||
Relays: d.GetRelayStates(),
|
||||
RosenpassState: d.GetRosenpassState(),
|
||||
NSGroupStates: d.GetDNSStates(),
|
||||
NumOfForwardingRules: len(d.ForwardingRules()),
|
||||
LazyConnectionEnabled: d.GetLazyConnection(),
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
|
||||
@@ -10,22 +10,24 @@ import (
|
||||
|
||||
func TestAddPeer(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "100.108.254.1"
|
||||
status := NewRecorder("https://mgm")
|
||||
err := status.AddPeer(key, "abc.netbird")
|
||||
err := status.AddPeer(key, "abc.netbird", ip)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
_, exists := status.peers[key]
|
||||
assert.True(t, exists, "value was found")
|
||||
|
||||
err = status.AddPeer(key, "abc.netbird")
|
||||
err = status.AddPeer(key, "abc.netbird", ip)
|
||||
|
||||
assert.Error(t, err, "should return error on duplicate")
|
||||
}
|
||||
|
||||
func TestGetPeer(t *testing.T) {
|
||||
key := "abc"
|
||||
ip := "100.108.254.1"
|
||||
status := NewRecorder("https://mgm")
|
||||
err := status.AddPeer(key, "abc.netbird")
|
||||
err := status.AddPeer(key, "abc.netbird", ip)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
peerStatus, err := status.GetPeer(key)
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -20,7 +21,7 @@ var (
|
||||
)
|
||||
|
||||
type WGInterfaceStater interface {
|
||||
GetStats(key string) (configurer.WGStats, error)
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
}
|
||||
|
||||
type WGWatcher struct {
|
||||
@@ -146,9 +147,13 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
||||
}
|
||||
|
||||
func (w *WGWatcher) wgState() (time.Time, error) {
|
||||
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
|
||||
wgStates, err := w.wgIfaceStater.GetStats()
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
wgState, ok := wgStates[w.peerKey]
|
||||
if !ok {
|
||||
return time.Time{}, fmt.Errorf("peer %s not found in WireGuard endpoints", w.peerKey)
|
||||
}
|
||||
return wgState.LastHandshake, nil
|
||||
}
|
||||
|
||||
@@ -11,26 +11,11 @@ import (
|
||||
)
|
||||
|
||||
type MocWgIface struct {
|
||||
initial bool
|
||||
lastHandshake time.Time
|
||||
stop bool
|
||||
stop bool
|
||||
}
|
||||
|
||||
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
|
||||
if !m.initial {
|
||||
m.initial = true
|
||||
return configurer.WGStats{}, nil
|
||||
}
|
||||
|
||||
if !m.stop {
|
||||
m.lastHandshake = time.Now()
|
||||
}
|
||||
|
||||
stats := configurer.WGStats{
|
||||
LastHandshake: m.lastHandshake,
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
func (m *MocWgIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
return map[string]configurer.WGStats{}, nil
|
||||
}
|
||||
|
||||
func (m *MocWgIface) disconnect() {
|
||||
|
||||
55
client/internal/peer/worker/state.go
Normal file
55
client/internal/peer/worker/state.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusDisconnected Status = iota
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
type Status int32
|
||||
|
||||
func (s Status) String() string {
|
||||
switch s {
|
||||
case StatusDisconnected:
|
||||
return "Disconnected"
|
||||
case StatusConnected:
|
||||
return "Connected"
|
||||
default:
|
||||
log.Errorf("unknown status: %d", s)
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// AtomicWorkerStatus is a thread-safe wrapper for worker status
|
||||
type AtomicWorkerStatus struct {
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
func NewAtomicStatus() *AtomicWorkerStatus {
|
||||
acs := &AtomicWorkerStatus{}
|
||||
acs.SetDisconnected()
|
||||
return acs
|
||||
}
|
||||
|
||||
// Get returns the current connection status
|
||||
func (acs *AtomicWorkerStatus) Get() Status {
|
||||
return Status(acs.status.Load())
|
||||
}
|
||||
|
||||
func (acs *AtomicWorkerStatus) SetConnected() {
|
||||
acs.status.Store(int32(StatusConnected))
|
||||
}
|
||||
|
||||
func (acs *AtomicWorkerStatus) SetDisconnected() {
|
||||
acs.status.Store(int32(StatusDisconnected))
|
||||
}
|
||||
|
||||
// String returns the string representation of the current status
|
||||
func (acs *AtomicWorkerStatus) String() string {
|
||||
return acs.Get().String()
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -397,10 +398,10 @@ func isRelayed(pair *ice.CandidatePair) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
|
||||
func selectedPriority(pair *ice.CandidatePair) conntype.ConnPriority {
|
||||
if isRelayed(pair) {
|
||||
return connPriorityICETurn
|
||||
return conntype.ICETurn
|
||||
} else {
|
||||
return connPriorityICEP2P
|
||||
return conntype.ICEP2P
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package peerstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -79,6 +80,32 @@ func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
|
||||
return p, true
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// this can be blocked because of the connect open limiter semaphore
|
||||
if err := p.Open(ctx); err != nil {
|
||||
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnClose(pubKey string) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func (s *Store) PeersPubKey() []string {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/client/common"
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||
@@ -41,6 +42,8 @@ type PKCEAuthProviderConfig struct {
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
@@ -100,6 +103,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package iface
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -18,5 +17,4 @@ type wgIfaceBase interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,10 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets := make([]string, 0)
|
||||
for _, r := range clientRoutes {
|
||||
// filter out domain routes
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
nets = append(nets, r.Network.String())
|
||||
}
|
||||
sort.Strings(nets)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
@@ -1,27 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
}
|
||||
|
||||
func (r serverRouter) cleanUp() {
|
||||
}
|
||||
|
||||
func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServerRouter(context.Context, iface.WGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -94,7 +94,7 @@ message LoginRequest {
|
||||
|
||||
bytes customDNSAddress = 7;
|
||||
|
||||
bool isLinuxDesktopClient = 8;
|
||||
bool isUnixDesktopClient = 8;
|
||||
|
||||
string hostname = 9;
|
||||
|
||||
@@ -134,6 +134,7 @@ message LoginRequest {
|
||||
// omits initialized empty slices due to omitempty tags
|
||||
bool cleanDNSLabels = 27;
|
||||
|
||||
optional bool lazyConnectionEnabled = 28;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -274,6 +275,8 @@ message FullStatus {
|
||||
int32 NumberOfForwardingRules = 8;
|
||||
|
||||
repeated SystemEvent events = 7;
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
}
|
||||
|
||||
// Networks
|
||||
|
||||
@@ -139,6 +139,7 @@ func (s *Server) Start() error {
|
||||
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
if s.sessionWatcher == nil {
|
||||
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
|
||||
@@ -417,6 +418,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
s.latestConfigInput.DisableNotifications = msg.DisableNotifications
|
||||
}
|
||||
|
||||
if msg.LazyConnectionEnabled != nil {
|
||||
inputConfig.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
s.latestConfigInput.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
@@ -446,7 +452,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
if msg.SetupKey == "" {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsLinuxDesktopClient)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
|
||||
if err != nil {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
return nil, err
|
||||
@@ -804,6 +810,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
pbPeerState := &proto.PeerState{
|
||||
|
||||
@@ -97,6 +97,7 @@ type OutputOverview struct {
|
||||
NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
|
||||
NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
Events []SystemEventOutput `json:"events" yaml:"events"`
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
}
|
||||
|
||||
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
|
||||
@@ -136,6 +137,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
|
||||
NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
Events: mapEvents(pbFullStatus.GetEvents()),
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
}
|
||||
|
||||
if anon {
|
||||
@@ -206,7 +208,7 @@ func mapPeers(
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
|
||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
@@ -384,6 +386,11 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
}
|
||||
}
|
||||
|
||||
lazyConnectionEnabledStatus := "false"
|
||||
if overview.LazyConnectionEnabled {
|
||||
lazyConnectionEnabledStatus = "true"
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
@@ -405,6 +412,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"Forwarding rules: %d\n"+
|
||||
"Peers count: %s\n",
|
||||
@@ -419,6 +427,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
networks,
|
||||
overview.NumberOfForwardingRules,
|
||||
peersCountString,
|
||||
@@ -533,23 +542,13 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
return peersString
|
||||
}
|
||||
|
||||
func skipDetailByFilters(
|
||||
peerState *proto.PeerState,
|
||||
isConnected bool,
|
||||
statusFilter string,
|
||||
prefixNamesFilter []string,
|
||||
prefixNamesFilterMap map[string]struct{},
|
||||
ipsFilter map[string]struct{},
|
||||
) bool {
|
||||
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
if lowerStatusFilter == "disconnected" && isConnected {
|
||||
statusEval = true
|
||||
} else if lowerStatusFilter == "connected" && !isConnected {
|
||||
if !strings.EqualFold(peerStatus, statusFilter) {
|
||||
statusEval = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -383,7 +383,8 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"error": "timeout"
|
||||
}
|
||||
],
|
||||
"events": []
|
||||
"events": [],
|
||||
"lazyConnectionEnabled": false
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
@@ -484,6 +485,7 @@ dnsServers:
|
||||
enabled: false
|
||||
error: timeout
|
||||
events: []
|
||||
lazyConnectionEnabled: false
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -548,6 +550,7 @@ FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
Networks: 10.10.0.0/24
|
||||
Forwarding rules: 0
|
||||
Peers count: 2/2 Connected
|
||||
@@ -570,6 +573,7 @@ FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
Networks: 10.10.0.0/24
|
||||
Forwarding rules: 0
|
||||
Peers count: 2/2 Connected
|
||||
|
||||
@@ -62,6 +62,8 @@ func main() {
|
||||
return
|
||||
}
|
||||
logFile = file
|
||||
} else {
|
||||
_ = util.InitLog("trace", "console")
|
||||
}
|
||||
|
||||
// Create the Fyne application.
|
||||
@@ -191,6 +193,7 @@ type serviceClient struct {
|
||||
mAllowSSH *systray.MenuItem
|
||||
mAutoConnect *systray.MenuItem
|
||||
mEnableRosenpass *systray.MenuItem
|
||||
mLazyConnEnabled *systray.MenuItem
|
||||
mNotifications *systray.MenuItem
|
||||
mAdvancedSettings *systray.MenuItem
|
||||
mCreateDebugBundle *systray.MenuItem
|
||||
@@ -382,12 +385,12 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
s.adminURL = iAdminURL
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
ManagementUrl: iMngURL,
|
||||
AdminURL: iAdminURL,
|
||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
|
||||
InterfaceName: &s.iInterfaceName.Text,
|
||||
WireguardPort: &port,
|
||||
ManagementUrl: iMngURL,
|
||||
AdminURL: iAdminURL,
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
|
||||
InterfaceName: &s.iInterfaceName.Text,
|
||||
WireguardPort: &port,
|
||||
}
|
||||
|
||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
@@ -414,7 +417,7 @@ func (s *serviceClient) login() error {
|
||||
}
|
||||
|
||||
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
|
||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("login to management URL with: %v", err)
|
||||
@@ -630,6 +633,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable lazy connection", lazyConnMenuDescr, false)
|
||||
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
|
||||
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr)
|
||||
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
|
||||
@@ -689,104 +693,114 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.mUp.ClickedCh:
|
||||
s.mUp.Disable()
|
||||
go func() {
|
||||
defer s.mUp.Enable()
|
||||
err := s.menuUpClick()
|
||||
if err != nil {
|
||||
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mDown.ClickedCh:
|
||||
s.mDown.Disable()
|
||||
go func() {
|
||||
defer s.mDown.Enable()
|
||||
err := s.menuDownClick()
|
||||
if err != nil {
|
||||
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mAllowSSH.ClickedCh:
|
||||
if s.mAllowSSH.Checked() {
|
||||
s.mAllowSSH.Uncheck()
|
||||
} else {
|
||||
s.mAllowSSH.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
case <-s.mAutoConnect.ClickedCh:
|
||||
if s.mAutoConnect.Checked() {
|
||||
s.mAutoConnect.Uncheck()
|
||||
} else {
|
||||
s.mAutoConnect.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
case <-s.mEnableRosenpass.ClickedCh:
|
||||
if s.mEnableRosenpass.Checked() {
|
||||
s.mEnableRosenpass.Uncheck()
|
||||
} else {
|
||||
s.mEnableRosenpass.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
case <-s.mAdvancedSettings.ClickedCh:
|
||||
s.mAdvancedSettings.Disable()
|
||||
go func() {
|
||||
defer s.mAdvancedSettings.Enable()
|
||||
defer s.getSrvConfig()
|
||||
s.runSelfCommand("settings", "true")
|
||||
}()
|
||||
case <-s.mCreateDebugBundle.ClickedCh:
|
||||
s.mCreateDebugBundle.Disable()
|
||||
go func() {
|
||||
defer s.mCreateDebugBundle.Enable()
|
||||
s.runSelfCommand("debug", "true")
|
||||
}()
|
||||
case <-s.mQuit.ClickedCh:
|
||||
systray.Quit()
|
||||
return
|
||||
case <-s.mGitHub.ClickedCh:
|
||||
err := openURL("https://github.com/netbirdio/netbird")
|
||||
if err != nil {
|
||||
log.Errorf("%s", err)
|
||||
}
|
||||
case <-s.mUpdate.ClickedCh:
|
||||
err := openURL(version.DownloadUrl())
|
||||
if err != nil {
|
||||
log.Errorf("%s", err)
|
||||
}
|
||||
case <-s.mNetworks.ClickedCh:
|
||||
s.mNetworks.Disable()
|
||||
go func() {
|
||||
defer s.mNetworks.Enable()
|
||||
s.runSelfCommand("networks", "true")
|
||||
}()
|
||||
case <-s.mNotifications.ClickedCh:
|
||||
if s.mNotifications.Checked() {
|
||||
s.mNotifications.Uncheck()
|
||||
} else {
|
||||
s.mNotifications.Check()
|
||||
}
|
||||
if s.eventManager != nil {
|
||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
}
|
||||
go s.listenEvents()
|
||||
}
|
||||
|
||||
func (s *serviceClient) listenEvents() {
|
||||
for {
|
||||
select {
|
||||
case <-s.mUp.ClickedCh:
|
||||
s.mUp.Disable()
|
||||
go func() {
|
||||
defer s.mUp.Enable()
|
||||
err := s.menuUpClick()
|
||||
if err != nil {
|
||||
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mDown.ClickedCh:
|
||||
s.mDown.Disable()
|
||||
go func() {
|
||||
defer s.mDown.Enable()
|
||||
err := s.menuDownClick()
|
||||
if err != nil {
|
||||
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mAllowSSH.ClickedCh:
|
||||
if s.mAllowSSH.Checked() {
|
||||
s.mAllowSSH.Uncheck()
|
||||
} else {
|
||||
s.mAllowSSH.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
case <-s.mAutoConnect.ClickedCh:
|
||||
if s.mAutoConnect.Checked() {
|
||||
s.mAutoConnect.Uncheck()
|
||||
} else {
|
||||
s.mAutoConnect.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
case <-s.mEnableRosenpass.ClickedCh:
|
||||
if s.mEnableRosenpass.Checked() {
|
||||
s.mEnableRosenpass.Uncheck()
|
||||
} else {
|
||||
s.mEnableRosenpass.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
case <-s.mLazyConnEnabled.ClickedCh:
|
||||
if s.mLazyConnEnabled.Checked() {
|
||||
s.mLazyConnEnabled.Uncheck()
|
||||
} else {
|
||||
s.mLazyConnEnabled.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
case <-s.mAdvancedSettings.ClickedCh:
|
||||
s.mAdvancedSettings.Disable()
|
||||
go func() {
|
||||
defer s.mAdvancedSettings.Enable()
|
||||
defer s.getSrvConfig()
|
||||
s.runSelfCommand("settings", "true")
|
||||
}()
|
||||
case <-s.mCreateDebugBundle.ClickedCh:
|
||||
s.mCreateDebugBundle.Disable()
|
||||
go func() {
|
||||
defer s.mCreateDebugBundle.Enable()
|
||||
s.runSelfCommand("debug", "true")
|
||||
}()
|
||||
case <-s.mQuit.ClickedCh:
|
||||
systray.Quit()
|
||||
return
|
||||
case <-s.mGitHub.ClickedCh:
|
||||
err := openURL("https://github.com/netbirdio/netbird")
|
||||
if err != nil {
|
||||
log.Errorf("%s", err)
|
||||
}
|
||||
case <-s.mUpdate.ClickedCh:
|
||||
err := openURL(version.DownloadUrl())
|
||||
if err != nil {
|
||||
log.Errorf("%s", err)
|
||||
}
|
||||
case <-s.mNetworks.ClickedCh:
|
||||
s.mNetworks.Disable()
|
||||
go func() {
|
||||
defer s.mNetworks.Enable()
|
||||
s.runSelfCommand("networks", "true")
|
||||
}()
|
||||
case <-s.mNotifications.ClickedCh:
|
||||
if s.mNotifications.Checked() {
|
||||
s.mNotifications.Uncheck()
|
||||
} else {
|
||||
s.mNotifications.Check()
|
||||
}
|
||||
if s.eventManager != nil {
|
||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) runSelfCommand(command, arg string) {
|
||||
@@ -1018,13 +1032,15 @@ func (s *serviceClient) updateConfig() error {
|
||||
sshAllowed := s.mAllowSSH.Checked()
|
||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||
notificationsDisabled := !s.mNotifications.Checked()
|
||||
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
DisableNotifications: ¬ificationsDisabled,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
}
|
||||
|
||||
if err := s.restartClient(&loginRequest); err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ const (
|
||||
allowSSHMenuDescr = "Allow SSH connections"
|
||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||
lazyConnMenuDescr = "[Experimental] Enable lazy connect"
|
||||
notificationsMenuDescr = "Enable notifications"
|
||||
advancedSettingsMenuDescr = "Advanced settings of the application"
|
||||
debugBundleMenuDescr = "Create and open debug information bundle"
|
||||
|
||||
4
go.mod
4
go.mod
@@ -59,13 +59,12 @@ require (
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@@ -195,6 +194,7 @@ require (
|
||||
github.com/libdns/libdns v0.2.2 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
|
||||
github.com/mholt/acmez/v2 v2.0.1 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -507,8 +507,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
|
||||
@@ -59,6 +59,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
|
||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
|
||||
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
|
||||
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
|
||||
# Dashboard
|
||||
@@ -122,6 +123,7 @@ export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
|
||||
export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
|
||||
export NETBIRD_AUTH_PKCE_LOGIN_FLAG
|
||||
export NETBIRD_AUTH_PKCE_AUDIENCE
|
||||
export NETBIRD_DASH_AUTH_USE_AUDIENCE
|
||||
export NETBIRD_DASH_AUTH_AUDIENCE
|
||||
|
||||
@@ -95,7 +95,8 @@
|
||||
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
|
||||
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
|
||||
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
|
||||
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
|
||||
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
|
||||
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,3 +28,4 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH
|
||||
NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
|
||||
NETBIRD_RELAY_PORT=33445
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0
|
||||
|
||||
19
management/client/common/types.go
Normal file
19
management/client/common/types.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package common
|
||||
|
||||
// LoginFlag introduces additional login flags to the PKCE authorization request
|
||||
type LoginFlag uint8
|
||||
|
||||
const (
|
||||
// LoginFlagPrompt adds prompt=login to the authorization request
|
||||
LoginFlagPrompt LoginFlag = iota
|
||||
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
|
||||
LoginFlagMaxAge0
|
||||
)
|
||||
|
||||
func (l LoginFlag) IsPromptLogin() bool {
|
||||
return l == LoginFlagPrompt
|
||||
}
|
||||
|
||||
func (l LoginFlag) IsMaxAge0Login() bool {
|
||||
return l == LoginFlagMaxAge0
|
||||
}
|
||||
@@ -260,8 +260,6 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
|
||||
|
||||
if err := msgHandler(decryptedResp); err != nil {
|
||||
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
|
||||
// hide any grpc error code that is not relevant for management
|
||||
return fmt.Errorf("msg handler error: %v", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type AccountsAPI struct {
|
||||
// List list all accounts, only returns one account always
|
||||
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
|
||||
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/accounts", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
|
||||
// Delete delete account
|
||||
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
|
||||
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -66,6 +66,15 @@ func TestAccounts_List_Err(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_List_ConnErr(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
ret, err := c.Accounts.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "404")
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
type Client struct {
|
||||
managementURL string
|
||||
authHeader string
|
||||
httpClient HttpClient
|
||||
|
||||
// Accounts NetBird account APIs
|
||||
// see more: https://docs.netbird.io/api/resources/accounts
|
||||
@@ -70,20 +72,29 @@ type Client struct {
|
||||
|
||||
// New initialize new Client instance using PAT token
|
||||
func New(managementURL, token string) *Client {
|
||||
client := &Client{
|
||||
managementURL: managementURL,
|
||||
authHeader: "Token " + token,
|
||||
}
|
||||
client.initialize()
|
||||
return client
|
||||
return NewWithOptions(
|
||||
WithManagementURL(managementURL),
|
||||
WithPAT(token),
|
||||
)
|
||||
}
|
||||
|
||||
// NewWithBearerToken initialize new Client instance using Bearer token type
|
||||
func NewWithBearerToken(managementURL, token string) *Client {
|
||||
return NewWithOptions(
|
||||
WithManagementURL(managementURL),
|
||||
WithBearerToken(token),
|
||||
)
|
||||
}
|
||||
|
||||
func NewWithOptions(opts ...option) *Client {
|
||||
client := &Client{
|
||||
managementURL: managementURL,
|
||||
authHeader: "Bearer " + token,
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
for _, option := range opts {
|
||||
option(client)
|
||||
}
|
||||
|
||||
client.initialize()
|
||||
return client
|
||||
}
|
||||
@@ -104,7 +115,7 @@ func (c *Client) initialize() {
|
||||
c.Events = &EventsAPI{c}
|
||||
}
|
||||
|
||||
func (c *Client) newRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
||||
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -116,7 +127,7 @@ func (c *Client) newRequest(ctx context.Context, method, path string, body io.Re
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -124,7 +135,8 @@ func (c *Client) newRequest(ctx context.Context, method, path string, body io.Re
|
||||
if resp.StatusCode > 299 {
|
||||
parsedErr, pErr := parseResponse[util.ErrorResponse](resp)
|
||||
if pErr != nil {
|
||||
return nil, err
|
||||
|
||||
return nil, pErr
|
||||
}
|
||||
return nil, errors.New(parsedErr.Message)
|
||||
}
|
||||
@@ -135,13 +147,16 @@ func (c *Client) newRequest(ctx context.Context, method, path string, body io.Re
|
||||
func parseResponse[T any](resp *http.Response) (T, error) {
|
||||
var ret T
|
||||
if resp.Body == nil {
|
||||
return ret, errors.New("No body")
|
||||
return ret, fmt.Errorf("Body missing, HTTP Error code %d", resp.StatusCode)
|
||||
}
|
||||
bs, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ret, err
|
||||
}
|
||||
err = json.Unmarshal(bs, &ret)
|
||||
if err != nil {
|
||||
return ret, fmt.Errorf("Error code %d, error unmarshalling body: %w", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
return ret, err
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type DNSAPI struct {
|
||||
// ListNameserverGroups list all nameserver groups
|
||||
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
|
||||
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou
|
||||
// GetNameserverGroup get nameserver group info
|
||||
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
|
||||
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
|
||||
// DeleteNameserverGroup delete nameserver group
|
||||
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
|
||||
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st
|
||||
// GetSettings get DNS settings
|
||||
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
|
||||
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/settings", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ type EventsAPI struct {
|
||||
// List list all events
|
||||
// See more: https://docs.netbird.io/api/resources/events#list-all-events
|
||||
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/events", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ type GeoLocationAPI struct {
|
||||
// ListCountries list all country codes
|
||||
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
|
||||
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro
|
||||
// ListCountryCities Get a list of all English city names for a given country code
|
||||
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
|
||||
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type GroupsAPI struct {
|
||||
// List list all groups
|
||||
// See more: https://docs.netbird.io/api/resources/groups#list-all-groups
|
||||
func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/groups", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
|
||||
// Get get group info
|
||||
// See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group
|
||||
func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/groups/"+groupID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
|
||||
// Delete delete group
|
||||
// See more: https://docs.netbird.io/api/resources/groups#delete-a-group
|
||||
func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/groups/"+groupID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type NetworksAPI struct {
|
||||
// List list all networks
|
||||
// See more: https://docs.netbird.io/api/resources/networks#list-all-networks
|
||||
func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/networks", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
|
||||
// Get get network info
|
||||
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network
|
||||
func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+networkID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
|
||||
// Delete delete network
|
||||
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network
|
||||
func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/networks/"+networkID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,7 +108,7 @@ func (a *NetworksAPI) Resources(networkID string) *NetworkResourcesAPI {
|
||||
// List list all resources in networks
|
||||
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-resources
|
||||
func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource,
|
||||
// Get get network resource info
|
||||
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network-resource
|
||||
func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) (*api.NetworkResource, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
|
||||
// Delete delete network resource
|
||||
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network-resource
|
||||
func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -200,7 +200,7 @@ func (a *NetworksAPI) Routers(networkID string) *NetworkRoutersAPI {
|
||||
// List list all routers in networks
|
||||
// See more: https://docs.netbird.io/api/routers/networks#list-all-network-routers
|
||||
func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -214,7 +214,7 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro
|
||||
// Get get network router info
|
||||
// See more: https://docs.netbird.io/api/routers/networks#retrieve-a-network-router
|
||||
func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*api.NetworkRouter, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -264,7 +264,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
|
||||
// Delete delete network router
|
||||
// See more: https://docs.netbird.io/api/routers/networks#delete-a-network-router
|
||||
func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
35
management/client/rest/options.go
Normal file
35
management/client/rest/options.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package rest
|
||||
|
||||
import "net/http"
|
||||
|
||||
type option func(*Client)
|
||||
|
||||
type HttpClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func WithHttpClient(client HttpClient) option {
|
||||
return func(c *Client) {
|
||||
c.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
func WithBearerToken(token string) option {
|
||||
return WithAuthHeader("Bearer " + token)
|
||||
}
|
||||
|
||||
func WithPAT(token string) option {
|
||||
return WithAuthHeader("Token " + token)
|
||||
}
|
||||
|
||||
func WithManagementURL(url string) option {
|
||||
return func(c *Client) {
|
||||
c.managementURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthHeader(value string) option {
|
||||
return func(c *Client) {
|
||||
c.authHeader = value
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,7 @@ type PeersAPI struct {
|
||||
// List list all peers
|
||||
// See more: https://docs.netbird.io/api/resources/peers#list-all-peers
|
||||
func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/peers", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) {
|
||||
// Get retrieve a peer
|
||||
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer
|
||||
func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/peers/"+peerID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
|
||||
// Delete delete a peer
|
||||
// See more: https://docs.netbird.io/api/resources/peers#delete-a-peer
|
||||
func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/peers/"+peerID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
|
||||
// ListAccessiblePeers list all peers that the specified peer can connect to within the network
|
||||
// See more: https://docs.netbird.io/api/resources/peers#list-accessible-peers
|
||||
func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]api.Peer, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@ type PoliciesAPI struct {
|
||||
// List list all policies
|
||||
// See more: https://docs.netbird.io/api/resources/policies#list-all-policies
|
||||
func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/policies", nil)
|
||||
path := "/api/policies"
|
||||
|
||||
resp, err := a.c.NewRequest(ctx, "GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +32,7 @@ func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
|
||||
// Get get policy info
|
||||
// See more: https://docs.netbird.io/api/resources/policies#retrieve-a-policy
|
||||
func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/policies/"+policyID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +50,7 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,11 +64,13 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO
|
||||
// Update update policy info
|
||||
// See more: https://docs.netbird.io/api/resources/policies#update-a-policy
|
||||
func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.PutApiPoliciesPolicyIdJSONRequestBody) (*api.Policy, error) {
|
||||
path := "/api/policies/" + policyID
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/policies/"+policyID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +84,7 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P
|
||||
// Delete delete policy
|
||||
// See more: https://docs.netbird.io/api/resources/policies#delete-a-policy
|
||||
func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/policies/"+policyID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type PostureChecksAPI struct {
|
||||
// List list all posture checks
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#list-all-posture-checks
|
||||
func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/posture-checks", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error)
|
||||
// Get get posture check info
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#retrieve-a-posture-check
|
||||
func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api.PostureCheck, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostur
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re
|
||||
// Delete delete posture check
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#delete-a-posture-check
|
||||
func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type RoutesAPI struct {
|
||||
// List list all routes
|
||||
// See more: https://docs.netbird.io/api/resources/routes#list-all-routes
|
||||
func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/routes", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
|
||||
// Get get route info
|
||||
// See more: https://docs.netbird.io/api/resources/routes#retrieve-a-route
|
||||
func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/routes/"+routeID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes/"+routeID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONReq
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA
|
||||
// Delete delete route
|
||||
// See more: https://docs.netbird.io/api/resources/routes#delete-a-route
|
||||
func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/routes/"+routeID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/routes/"+routeID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type SetupKeysAPI struct {
|
||||
// List list all setup keys
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#list-all-setup-keys
|
||||
func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/setup-keys", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
|
||||
// Get get setup key info
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#retrieve-a-setup-key
|
||||
func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKey, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -44,11 +44,13 @@ func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKe
|
||||
// Create generate new Setup Key
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#create-a-setup-key
|
||||
func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJSONRequestBody) (*api.SetupKeyClear, error) {
|
||||
path := "/api/setup-keys"
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/setup-keys", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", path, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +68,7 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +82,7 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap
|
||||
// Delete delete setup key
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#delete-a-setup-key
|
||||
func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type TokensAPI struct {
|
||||
// List list user tokens
|
||||
// See more: https://docs.netbird.io/api/resources/tokens#list-all-tokens
|
||||
func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAccessToken, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAcce
|
||||
// Get get user token info
|
||||
// See more: https://docs.netbird.io/api/resources/tokens#retrieve-a-token
|
||||
func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.PersonalAccessToken, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA
|
||||
// Delete delete user token
|
||||
// See more: https://docs.netbird.io/api/resources/tokens#delete-a-token
|
||||
func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type UsersAPI struct {
|
||||
// List list all users, only returns one user always
|
||||
// See more: https://docs.netbird.io/api/resources/users#list-all-users
|
||||
func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/users", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONReque
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes))
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi
|
||||
// Delete delete user
|
||||
// See more: https://docs.netbird.io/api/resources/users#delete-a-user
|
||||
func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/users/"+userID, nil)
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
|
||||
// ResendInvitation resend user invitation
|
||||
// See more: https://docs.netbird.io/api/resources/users#resend-user-invitation
|
||||
func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
|
||||
resp, err := a.c.newRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
|
||||
// Current gets the current user info
|
||||
// See more: https://docs.netbird.io/api/resources/users#retrieve-current-user
|
||||
func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
|
||||
resp, err := a.c.newRequest(ctx, "GET", "/api/users/current", nil)
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/current", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ var (
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
|
||||
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
|
||||
const maxDomains = 32
|
||||
|
||||
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||
|
||||
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
|
||||
func ValidateDomains(domains []string) (List, error) {
|
||||
if len(domains) == 0 {
|
||||
@@ -17,8 +19,6 @@ func ValidateDomains(domains []string) (List, error) {
|
||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||
}
|
||||
|
||||
domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||
|
||||
var domainList List
|
||||
|
||||
for _, d := range domains {
|
||||
@@ -37,27 +37,20 @@ func ValidateDomains(domains []string) (List, error) {
|
||||
return domainList, nil
|
||||
}
|
||||
|
||||
// ValidateDomainsStrSlice checks if each domain in the list is valid
|
||||
func ValidateDomainsStrSlice(domains []string) ([]string, error) {
|
||||
// ValidateDomainsList checks if each domain in the list is valid
|
||||
func ValidateDomainsList(domains []string) error {
|
||||
if len(domains) == 0 {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
if len(domains) > maxDomains {
|
||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||
}
|
||||
|
||||
domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||
|
||||
var domainList []string
|
||||
|
||||
for _, d := range domains {
|
||||
d := strings.ToLower(d)
|
||||
|
||||
if !domainRegex.MatchString(d) {
|
||||
return domainList, fmt.Errorf("invalid domain format: %s", d)
|
||||
return fmt.Errorf("invalid domain format: %s", d)
|
||||
}
|
||||
|
||||
domainList = append(domainList, d)
|
||||
}
|
||||
return domainList, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -97,110 +97,89 @@ func TestValidateDomains(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateDomainsStrSlice tests the ValidateDomainsStrSlice function.
|
||||
func TestValidateDomainsStrSlice(t *testing.T) {
|
||||
// Generate a slice of valid domains up to maxDomains
|
||||
func TestValidateDomainsList(t *testing.T) {
|
||||
validDomains := make([]string, maxDomains)
|
||||
for i := 0; i < maxDomains; i++ {
|
||||
for i := range maxDomains {
|
||||
validDomains[i] = fmt.Sprintf("example%d.com", i)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
expected []string
|
||||
wantErr bool
|
||||
name string
|
||||
domains []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty list",
|
||||
domains: nil,
|
||||
expected: nil,
|
||||
wantErr: false,
|
||||
name: "Empty list",
|
||||
domains: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Single valid ASCII domain",
|
||||
domains: []string{"sub.ex-ample.com"},
|
||||
expected: []string{"sub.ex-ample.com"},
|
||||
wantErr: false,
|
||||
name: "Single valid ASCII domain",
|
||||
domains: []string{"sub.ex-ample.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Underscores in labels",
|
||||
domains: []string{"_jabber._tcp.gmail.com"},
|
||||
expected: []string{"_jabber._tcp.gmail.com"},
|
||||
wantErr: false,
|
||||
name: "Underscores in labels",
|
||||
domains: []string{"_jabber._tcp.gmail.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// Unlike ValidateDomains (which converts to punycode),
|
||||
// ValidateDomainsStrSlice will fail on non-ASCII domain chars.
|
||||
name: "Unicode domain fails (no punycode conversion)",
|
||||
domains: []string{"münchen.de"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
name: "Unicode domain fails (no punycode conversion)",
|
||||
domains: []string{"münchen.de"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format - leading dash",
|
||||
domains: []string{"-example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
name: "Invalid domain format - leading dash",
|
||||
domains: []string{"-example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format - trailing dash",
|
||||
domains: []string{"example-.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
name: "Invalid domain format - trailing dash",
|
||||
domains: []string{"example-.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// The function stops on the first invalid domain and returns an error,
|
||||
// so only the first domain is definitely valid, but the second is invalid.
|
||||
name: "Multiple domains with a valid one, then invalid",
|
||||
domains: []string{"google.com", "invalid_domain.com-"},
|
||||
expected: []string{"google.com"},
|
||||
wantErr: true,
|
||||
name: "Multiple domains with a valid one, then invalid",
|
||||
domains: []string{"google.com", "invalid_domain.com-"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Valid wildcard domain",
|
||||
domains: []string{"*.example.com"},
|
||||
expected: []string{"*.example.com"},
|
||||
wantErr: false,
|
||||
name: "Valid wildcard domain",
|
||||
domains: []string{"*.example.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Wildcard with leading dot - invalid",
|
||||
domains: []string{".*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
name: "Wildcard with leading dot - invalid",
|
||||
domains: []string{".*.example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid wildcard with multiple asterisks",
|
||||
domains: []string{"a.*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
name: "Invalid wildcard with multiple asterisks",
|
||||
domains: []string{"a.*.example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Exactly maxDomains items (valid)",
|
||||
domains: validDomains,
|
||||
expected: validDomains,
|
||||
wantErr: false,
|
||||
name: "Exactly maxDomains items (valid)",
|
||||
domains: validDomains,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Exceeds maxDomains items",
|
||||
domains: append(validDomains, "extra.com"),
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
name: "Exceeds maxDomains items",
|
||||
domains: append(validDomains, "extra.com"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ValidateDomainsStrSlice(tt.domains)
|
||||
// Check if we got an error where expected
|
||||
err := ValidateDomainsList(tt.domains)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Compare the returned domains to what we expect
|
||||
assert.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user