mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 07:04:17 -04:00
Compare commits
102 Commits
feature/re
...
refactor/n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
575b176371 | ||
|
|
b96dfb834e | ||
|
|
a654c286bf | ||
|
|
20c65cffc2 | ||
|
|
95cf5ab673 | ||
|
|
ed0c5a470a | ||
|
|
3a11175d34 | ||
|
|
d82176e9ea | ||
|
|
56ebd004d9 | ||
|
|
1ee575befe | ||
|
|
d3a34adcc9 | ||
|
|
bbb0910a2c | ||
|
|
ec035e5709 | ||
|
|
d7321c130b | ||
|
|
404cab90ba | ||
|
|
4545ab9a52 | ||
|
|
7f08983207 | ||
|
|
eddea14521 | ||
|
|
b9ef214ea5 | ||
|
|
709e24eb6f | ||
|
|
6654e2dbf7 | ||
|
|
d80d47a469 | ||
|
|
96f71ff1e1 | ||
|
|
2fe2af38d2 | ||
|
|
cd9a867ad0 | ||
|
|
0f9bfeff7c | ||
|
|
d0f8868e4d | ||
|
|
31e3928e08 | ||
|
|
1793ad8ff0 | ||
|
|
f5301230bf | ||
|
|
7e51803a58 | ||
|
|
4cab69a76c | ||
|
|
dace4ff30c | ||
|
|
429d7d6585 | ||
|
|
3cdb10cde7 | ||
|
|
069586b77f | ||
|
|
dac6fb16ba | ||
|
|
cd7dce7678 | ||
|
|
9bdabfbb2c | ||
|
|
2760c78bc2 | ||
|
|
af95aabb03 | ||
|
|
3abae0bd17 | ||
|
|
8252ff41db | ||
|
|
277aa2b7cc | ||
|
|
bb37dc89ce | ||
|
|
2882638c7f | ||
|
|
000e99e7f3 | ||
|
|
0d2e67983a | ||
|
|
5151f19d29 | ||
|
|
bedd3cabc9 | ||
|
|
9c7c15ba8b | ||
|
|
b0dcc9ff7b | ||
|
|
d35a845dbd | ||
|
|
11e2573747 | ||
|
|
4e03f708a4 | ||
|
|
6ac9e58911 | ||
|
|
c115434d53 | ||
|
|
3984f0ee08 | ||
|
|
2dc5e7eb7a | ||
|
|
75dc7f6517 | ||
|
|
409e2013cc | ||
|
|
654aa9581d | ||
|
|
f13c6ba1b8 | ||
|
|
fc2f23bd22 | ||
|
|
26f5aee4b9 | ||
|
|
8d47175cdb | ||
|
|
9021bb512b | ||
|
|
768332820e | ||
|
|
229c65ffa1 | ||
|
|
4d33567888 | ||
|
|
274711a37e | ||
|
|
53e24ae7f7 | ||
|
|
fbc02343e9 | ||
|
|
ffed4b38ef | ||
|
|
88467883fc | ||
|
|
954f40991f | ||
|
|
34341d95a9 | ||
|
|
e926ca34b5 | ||
|
|
5d1c61369d | ||
|
|
fd9e21a5f3 | ||
|
|
841bc7564a | ||
|
|
f20a1b3328 | ||
|
|
2ac0da6cac | ||
|
|
148b8b04b3 | ||
|
|
9a56883ffb | ||
|
|
806be13dd5 | ||
|
|
90557da237 | ||
|
|
1d6209841e | ||
|
|
8f0e5708d5 | ||
|
|
5a9aa55121 | ||
|
|
6082c7cdcb | ||
|
|
06eae13352 | ||
|
|
08fba9876b | ||
|
|
ca85aa9b8f | ||
|
|
0ae2241573 | ||
|
|
050c05164a | ||
|
|
333908d06e | ||
|
|
bc6c5ece6e | ||
|
|
fd7b3ae21c | ||
|
|
abd7a84a46 | ||
|
|
f4b2bed1b9 | ||
|
|
2fb971e88a |
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.0
|
||||
FROM alpine:3.22.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
statusResp, err := getStatus(cmd.Context(), true)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
if err := openBrowser(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||
func openBrowser(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
return open.Run(url)
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
|
||||
@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
resp, err := getStatus(ctx)
|
||||
resp, err := getStatus(ctx, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
||||
return ""
|
||||
}
|
||||
|
||||
// Include action in the ipset name to prevent squashing rules with different actions
|
||||
actionSuffix := ""
|
||||
if action == firewall.ActionDrop {
|
||||
actionSuffix = "-drop"
|
||||
|
||||
@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
ruleInfo := ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = ruleInfo.rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -151,14 +151,20 @@ type Manager interface {
|
||||
|
||||
DisableRouting() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
|
||||
AddDNATRule(ForwardRule) (Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
// DeleteDNATRule deletes the outbound DNAT rule.
|
||||
DeleteDNATRule(Rule) error
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 3,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -22,6 +22,8 @@ type BaseConnTrack struct {
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
|
||||
DNATOrigPort atomic.Uint32
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
|
||||
@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
|
||||
if exists {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
|
||||
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
|
||||
@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
|
||||
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
|
||||
if exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
|
||||
@@ -109,6 +109,10 @@ type Manager struct {
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATRules []portDNATRule
|
||||
portDNATMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -122,6 +126,8 @@ type decoder struct {
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser *gopacket.DecodingLayerParser
|
||||
|
||||
dnatOrigPort uint16
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
@@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
|
||||
return false
|
||||
@@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite UDP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite TCP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
}
|
||||
@@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
}
|
||||
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
@@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
|
||||
@@ -50,6 +50,8 @@ type logMessage struct {
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
|
||||
func (l *Logger) Error(format string) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
|
||||
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
*buf = (*buf)[:0]
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
argCount++
|
||||
if msg.arg6 != nil {
|
||||
argCount++
|
||||
if msg.arg7 != nil {
|
||||
argCount++
|
||||
if msg.arg8 != nil {
|
||||
argCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
|
||||
case 6:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
|
||||
case 7:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
|
||||
case 8:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
|
||||
}
|
||||
|
||||
*buf = append(*buf, formatted...)
|
||||
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error {
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -13,6 +15,21 @@ import (
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
var (
|
||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||
)
|
||||
|
||||
const (
|
||||
// Port offsets in TCP/UDP headers
|
||||
sourcePortOffset = 0
|
||||
destinationPortOffset = 2
|
||||
|
||||
// IP address offsets in IPv4 header
|
||||
sourceIPOffset = 12
|
||||
destinationIPOffset = 16
|
||||
)
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// icmpChecksum calculates ICMP checksum.
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// biDNATMap maintains bidirectional DNAT mappings.
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
// portDNATRule represents a port-specific DNAT rule.
|
||||
type portDNATRule struct {
|
||||
protocol gopacket.LayerType
|
||||
origPort uint16
|
||||
targetPort uint16
|
||||
targetIP netip.Addr
|
||||
}
|
||||
|
||||
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
|
||||
}
|
||||
}
|
||||
|
||||
// set adds a bidirectional DNAT mapping between original and translated addresses.
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
}
|
||||
|
||||
// delete removes a bidirectional DNAT mapping for the given original address.
|
||||
func (b *biDNATMap) delete(original netip.Addr) {
|
||||
if translated, exists := b.forward[original]; exists {
|
||||
delete(b.forward, original)
|
||||
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
// getTranslated returns the translated address for a given original address.
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// getOriginal returns the original address for a given translated address.
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
if !originalAddr.IsValid() {
|
||||
return fmt.Errorf("invalid original IP address")
|
||||
}
|
||||
if !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid translated IP address")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
// getDNATTranslation returns the translated address if a mapping exists.
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
// findReverseDNATMapping finds original address for return traffic.
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets.
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error1("Failed to rewrite packet destination: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic.
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error1("Failed to rewrite packet source: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||
if !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldDst [4]byte
|
||||
copy(oldDst[:], packetData[16:20])
|
||||
newDst := newIP.As4()
|
||||
var oldIP [4]byte
|
||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||
newIPBytes := newIP.As4()
|
||||
|
||||
copy(packetData[16:20], newDst[:])
|
||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldSrc [4]byte
|
||||
copy(oldSrc[:], packetData[12:16])
|
||||
newSrc := newIP.As4()
|
||||
|
||||
copy(packetData[12:16], newSrc[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateICMPChecksum recalculates ICMP checksum after packet modification.
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
// DeleteDNATRule deletes outbound DNAT rule.
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// addPortRedirection adds a port redirection rule.
|
||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
rule := portDNATRule{
|
||||
protocol: protocol,
|
||||
origPort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
}
|
||||
|
||||
m.portDNATRules = append(m.portDNATRules, rule)
|
||||
m.portDNATEnabled.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
layerType = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
layerType = layers.LayerTypeUDP
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// removePortRedirection removes a port redirection rule.
|
||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
|
||||
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
|
||||
})
|
||||
|
||||
if len(m.portDNATRules) == 0 {
|
||||
m.portDNATEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
layerType = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
layerType = layers.LayerTypeUDP
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
|
||||
|
||||
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATRules {
|
||||
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.origPort != port {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
|
||||
portStart := tcpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
if len(packetData) >= tcpStart+18 {
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return fmt.Errorf("packet too short for UDP header")
|
||||
}
|
||||
|
||||
portStart := udpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
checksumOffset := udpStart + 6
|
||||
if len(packetData) >= udpStart+8 {
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
if oldChecksum != 0 {
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkPortDNAT measures the performance of port DNAT operations
|
||||
func BenchmarkPortDNAT(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
setupDNAT bool
|
||||
useMatchPort bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "tcp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "TCP inbound port DNAT translation (22 → 22022)",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "UDP inbound port DNAT translation (5353 → 22054)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound without DNAT (baseline)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
var origPort, targetPort, testPort uint16
|
||||
if sc.proto == layers.IPProtocolTCP {
|
||||
origPort, targetPort = 22, 22022
|
||||
} else {
|
||||
origPort, targetPort = 5353, 22054
|
||||
}
|
||||
|
||||
if sc.useMatchPort {
|
||||
testPort = origPort
|
||||
} else {
|
||||
testPort = 443 // Different port
|
||||
}
|
||||
|
||||
// Setup port DNAT mapping if needed
|
||||
if sc.setupDNAT {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Pre-establish inbound connection for outbound reverse test
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
|
||||
manager.filterInbound(inboundPacket, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// Benchmark inbound DNAT translation
|
||||
b.Run("inbound", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time
|
||||
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
|
||||
manager.filterInbound(packet, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
b.Run("outbound_reverse", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh return packet (from target port)
|
||||
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
sourcePort uint16
|
||||
targetPort uint16
|
||||
}{
|
||||
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
|
||||
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
|
||||
d := parsePacket(t, inboundPacket)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
|
||||
require.True(t, translated, "Inbound packet should be translated")
|
||||
|
||||
d = parsePacket(t, inboundPacket)
|
||||
var dstPort uint16
|
||||
switch tc.protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.IPProtocolUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
|
||||
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
|
||||
|
||||
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
|
||||
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
|
||||
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
|
||||
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||
d := parsePacket(t, packet)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
|
||||
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
|
||||
|
||||
d = parsePacket(t, packet)
|
||||
if tc.protocol == layers.IPProtocolTCP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
|
||||
} else if tc.protocol == layers.IPProtocolUDP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
|
||||
switch proto {
|
||||
case layers.IPProtocolTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.IPProtocolUDP:
|
||||
return firewall.ProtocolUDP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,25 +16,33 @@ type PacketStage int
|
||||
|
||||
const (
|
||||
StageReceived PacketStage = iota
|
||||
StageInboundPortDNAT
|
||||
StageInbound1to1NAT
|
||||
StageConntrack
|
||||
StagePeerACL
|
||||
StageRouting
|
||||
StageRouteACL
|
||||
StageForwarding
|
||||
StageCompleted
|
||||
StageOutbound1to1NAT
|
||||
StageOutboundPortReverse
|
||||
)
|
||||
|
||||
const msgProcessingCompleted = "Processing completed"
|
||||
|
||||
func (s PacketStage) String() string {
|
||||
return map[PacketStage]string{
|
||||
StageReceived: "Received",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageReceived: "Received",
|
||||
StageInboundPortDNAT: "Inbound Port DNAT",
|
||||
StageInbound1to1NAT: "Inbound 1:1 NAT",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageOutbound1to1NAT: "Outbound 1:1 NAT",
|
||||
StageOutboundPortReverse: "Outbound DNAT Reverse",
|
||||
}[s]
|
||||
}
|
||||
|
||||
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||
return trace
|
||||
}
|
||||
|
||||
m.handleOutboundDNAT(trace, packetData, d)
|
||||
|
||||
dropped := m.filterOutbound(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
}
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||
if portDNATApplied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
trace.DestinationPort = m.getDestPort(d)
|
||||
}
|
||||
|
||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||
if nat1to1Applied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
protocol := d.decoded[1]
|
||||
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
var originalPort uint16
|
||||
if protocol == layers.LayerTypeTCP {
|
||||
originalPort = uint16(d.tcp.DstPort)
|
||||
} else {
|
||||
originalPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
|
||||
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
|
||||
if translated {
|
||||
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
|
||||
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
|
||||
|
||||
protoStr := "TCP"
|
||||
if protocol == layers.LayerTypeUDP {
|
||||
protoStr = "UDP"
|
||||
}
|
||||
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
|
||||
trace.AddResult(StageInboundPortDNAT, msg, true)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
|
||||
trace.AddResult(StageInbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
|
||||
m.traceOutbound1to1NAT(trace, packetData, d)
|
||||
m.traceOutboundPortReverse(trace, packetData, d)
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatMappings[dstIP]
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
|
||||
trace.AddResult(StageOutbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
var origPort uint16
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeTCP:
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
srcPort := uint16(d.udp.SrcPort)
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
default:
|
||||
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) getDestPort(d *decoder) uint16 {
|
||||
if len(d.decoded) < 2 {
|
||||
return 0
|
||||
}
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
@@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
@@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
@@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
|
||||
@@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if tlsEnabled {
|
||||
// for js, the outer websocket layer takes care of tls
|
||||
if tlsEnabled && runtime.GOOS != "js" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
@@ -37,9 +38,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
}
|
||||
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
|
||||
InsecureSkipVerify: runtime.GOOS == "js",
|
||||
RootCAs: certPool,
|
||||
RootCAs: certPool,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the existing peer to preserve its allowed IPs
|
||||
existingPeer, err := c.getPeer(c.deviceName, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get peer: %w", err)
|
||||
}
|
||||
|
||||
removePeerCfg := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
|
||||
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
|
||||
}
|
||||
|
||||
//Re-add the peer without the endpoint but same AllowedIPs
|
||||
reAddPeerCfg := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
AllowedIPs: existingPeer.AllowedIPs,
|
||||
ReplaceAllowedIPs: true,
|
||||
}
|
||||
|
||||
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
|
||||
return fmt.Errorf(
|
||||
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
|
||||
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse peer key: %w", err)
|
||||
}
|
||||
|
||||
ipcStr, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get IPC config: %w", err)
|
||||
}
|
||||
|
||||
// Parse current status to get allowed IPs for the peer
|
||||
stats, err := parseStatus(c.deviceName, ipcStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse IPC config: %w", err)
|
||||
}
|
||||
|
||||
var allowedIPs []net.IPNet
|
||||
found := false
|
||||
for _, peer := range stats.Peers {
|
||||
if peer.PublicKey == peerKey {
|
||||
allowedIPs = peer.AllowedIPs
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("peer %s not found", peerKey)
|
||||
}
|
||||
|
||||
// remove the peer from the WireGuard configuration
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||
return fmt.Errorf("failed to remove peer: %s", ipcErr)
|
||||
}
|
||||
|
||||
// Build the peer config
|
||||
peer = wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: allowedIPs,
|
||||
}
|
||||
|
||||
config = wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
|
||||
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
|
||||
return fmt.Errorf("remove endpoint address: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,4 +23,5 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *WGTunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
func routesToString(routes []string) string {
|
||||
return strings.Join(routes, ";")
|
||||
}
|
||||
|
||||
@@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
|
||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns nil for kernel mode devices
|
||||
func (t *TunKernelDevice) GetICEBind() EndpointManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type Bind interface {
|
||||
conn.Bind
|
||||
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
ActivityRecorder() *bind.ActivityRecorder
|
||||
EndpointManager
|
||||
}
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
@@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
|
||||
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
||||
return t.net
|
||||
}
|
||||
|
||||
// GetICEBind returns the bind instance
|
||||
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
|
||||
return t.bind
|
||||
}
|
||||
|
||||
@@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
|
||||
func (t *USPDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *USPDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
13
client/iface/device/endpoint_manager.go
Normal file
13
client/iface/device/endpoint_manager.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
|
||||
// Implemented by bind.ICEBind and bind.RelayBindJS.
|
||||
type EndpointManager interface {
|
||||
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
|
||||
RemoveEndpoint(fakeIP netip.Addr)
|
||||
}
|
||||
@@ -21,4 +21,5 @@ type WGConfigurer interface {
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
RemoveEndpointAddress(peerKey string) error
|
||||
}
|
||||
|
||||
@@ -21,4 +21,5 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.tun == nil {
|
||||
return nil
|
||||
}
|
||||
return w.tun.GetICEBind()
|
||||
}
|
||||
|
||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||
func (w *WGIface) IsUserspaceBind() bool {
|
||||
return w.userspaceBind
|
||||
@@ -148,6 +159,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
|
||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Removing endpoint address: %s", peerKey)
|
||||
return w.configurer.RemoveEndpointAddress(peerKey)
|
||||
}
|
||||
|
||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
|
||||
@@ -29,11 +29,6 @@ type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||
}
|
||||
|
||||
type protoMatch struct {
|
||||
ips map[string]int
|
||||
policyID []byte
|
||||
}
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
rules := networkMap.FirewallRules
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
|
||||
// if TCP protocol rules not squashed and SSH enabled
|
||||
// we add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP 22 port).
|
||||
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
}
|
||||
|
||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
||||
//
|
||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
||||
// but other has port definitions or has drop policy.
|
||||
func (d *DefaultManager) squashAcceptRules(
|
||||
networkMap *mgmProto.NetworkMap,
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
||||
totalIPs := 0
|
||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||
for range p.AllowedIps {
|
||||
totalIPs++
|
||||
}
|
||||
}
|
||||
|
||||
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
// But we zeroed the IP's for protocol if:
|
||||
// 1. Any of the rule has DROP action type.
|
||||
// 2. Any of rule contains Port.
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||
|
||||
if hasPortRestrictions {
|
||||
// Don't squash rules with port restrictions
|
||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = &protoMatch{
|
||||
ips: map[string]int{},
|
||||
// store the first encountered PolicyID for this protocol
|
||||
policyID: r.PolicyID,
|
||||
}
|
||||
}
|
||||
|
||||
// special case, when we receive this all network IP address
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
squashedRules = append(squashedRules, r)
|
||||
squashedProtocols[r.Protocol] = struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
ipset := protocols[r.Protocol].ips
|
||||
|
||||
if _, ok := ipset[r.PeerIP]; ok {
|
||||
return
|
||||
}
|
||||
ipset[r.PeerIP] = i
|
||||
}
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
// calculate squash for different directions
|
||||
if r.Direction == mgmProto.RuleDirection_IN {
|
||||
addRuleToCalculationMap(i, r, in)
|
||||
} else {
|
||||
addRuleToCalculationMap(i, r, out)
|
||||
}
|
||||
}
|
||||
|
||||
// order of squashing by protocol is important
|
||||
// only for their first element ALL, it must be done first
|
||||
protocolOrders := []mgmProto.RuleProtocol{
|
||||
mgmProto.RuleProtocol_ALL,
|
||||
mgmProto.RuleProtocol_ICMP,
|
||||
mgmProto.RuleProtocol_TCP,
|
||||
mgmProto.RuleProtocol_UDP,
|
||||
}
|
||||
|
||||
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
match, ok := matches[protocol]
|
||||
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
||||
// don't squash if :
|
||||
// 1. Rules not cover all peers in the network
|
||||
// 2. Rules cover only one peer in the network.
|
||||
continue
|
||||
}
|
||||
|
||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: direction,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: protocol,
|
||||
PolicyID: match.policyID,
|
||||
})
|
||||
squashedProtocols[protocol] = struct{}{}
|
||||
|
||||
if protocol == mgmProto.RuleProtocol_ALL {
|
||||
// if we have ALL traffic type squashed rule
|
||||
// it allows all other type of traffic, so we can stop processing
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
squash(in, mgmProto.RuleDirection_IN)
|
||||
squash(out, mgmProto.RuleDirection_OUT)
|
||||
|
||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
return squashedRules, squashedProtocols
|
||||
}
|
||||
|
||||
if len(squashedRules) == 0 {
|
||||
return networkMap.FirewallRules, squashedProtocols
|
||||
}
|
||||
|
||||
var rules []*mgmProto.FirewallRule
|
||||
// filter out rules which was squashed from final list
|
||||
// if we also have other not squashed rules.
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
||||
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
}
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
return append(rules, squashedRules...), squashedProtocols
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
|
||||
@@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, 2, len(rules))
|
||||
|
||||
r := rules[0]
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
|
||||
r = rules[1]
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []*mgmProto.FirewallRule
|
||||
expectedCount int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "should not squash rules with port ranges",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with specific ports",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with legacy port field",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with legacy port field should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with DROP action",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with DROP action should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should squash rules without port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 1,
|
||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||
},
|
||||
{
|
||||
name: "mixed rules should not squash protocol with port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "TCP should not be squashed because one rule has port restrictions",
|
||||
},
|
||||
{
|
||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
// TCP rules with port restrictions - should NOT be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
// UDP rules without port restrictions - SHOULD be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: tt.rules,
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
|
||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||
|
||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||
if tt.expectedCount == 1 {
|
||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
block.prof: Block profiling information.
|
||||
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Adding state file from: %s", path)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
|
||||
@@ -14,6 +14,9 @@ type WGIface interface {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addWgShow() error {
|
||||
if g.statusRecorder == nil {
|
||||
return fmt.Errorf("no status recorder available for wg show")
|
||||
}
|
||||
result, err := g.statusRecorder.PeersStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
var err error
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
err = s.recordSystemDNSSettings(true)
|
||||
if err != nil {
|
||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
searchDomains = append(searchDomains, "\"\"")
|
||||
err = s.addLocalDNS()
|
||||
if err != nil {
|
||||
log.Infof("failed to enable split DNS")
|
||||
if err := s.addLocalDNS(); err != nil {
|
||||
log.Warnf("failed to add local DNS: %v", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
}
|
||||
|
||||
for _, dConf := range config.Domains {
|
||||
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
var err error
|
||||
if len(matchDomains) != 0 {
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
if err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
if err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) string() string {
|
||||
return "scutil"
|
||||
}
|
||||
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
func (s *systemConfigurator) addLocalDNS() error {
|
||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||
log.Errorf("Unable to get system DNS configuration")
|
||||
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
||||
}
|
||||
}
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
|
||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
||||
}
|
||||
} else {
|
||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||
log.Info("Not enabling local DNS server")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.addSearchDomains(
|
||||
localKey,
|
||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||
); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
111
client/internal/dns/host_darwin_test.go
Normal file
111
client/internal/dns/host_darwin_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping scutil integration test in short mode")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
stateFile := filepath.Join(tmpDir, "state.json")
|
||||
|
||||
sm := statemanager.New(stateFile)
|
||||
sm.RegisterState(&ShutdownState{})
|
||||
sm.Start()
|
||||
defer func() {
|
||||
require.NoError(t, sm.Stop(context.Background()))
|
||||
}()
|
||||
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netip.MustParseAddr("100.64.0.1"),
|
||||
ServerPort: 53,
|
||||
RouteAll: true,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "example.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err := configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, sm.PersistState(context.Background()))
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
defer func() {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
if exists {
|
||||
t.Logf("Key %s exists before cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
sm2 := statemanager.New(stateFile)
|
||||
sm2.RegisterState(&ShutdownState{})
|
||||
err = sm2.LoadState(&ShutdownState{})
|
||||
require.NoError(t, err)
|
||||
|
||||
state := sm2.GetState(&ShutdownState{})
|
||||
if state == nil {
|
||||
t.Skip("State not saved, skipping cleanup test")
|
||||
}
|
||||
|
||||
shutdownState, ok := state.(*ShutdownState)
|
||||
require.True(t, ok)
|
||||
|
||||
err = shutdownState.Cleanup()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNSKeyExists(key string) (bool, error) {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
if strings.Contains(string(output), "No such key") {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return !strings.Contains(string(output), "No such key"), nil
|
||||
}
|
||||
|
||||
func removeTestDNSKey(key string) error {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
|
||||
_, err := cmd.CombinedOutput()
|
||||
return err
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
r.updateState(stateManager)
|
||||
|
||||
var searchDomains, matchDomains []string
|
||||
for _, dConf := range config.Domains {
|
||||
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
log.Errorf("cleanup old dns match policies: %s", err)
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
if err != nil {
|
||||
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
return fmt.Errorf("remove dns match policies: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
r.updateState(stateManager)
|
||||
|
||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
|
||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||
}
|
||||
|
||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
|
||||
102
client/internal/dns/host_windows_test.go
Normal file
102
client/internal/dns/host_windows_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
// when the number of match domains decreases between configuration changes.
|
||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
}
|
||||
|
||||
defer cleanupRegistryKeys(t)
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
testIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||
require.NoError(t, err, "Should create test interface registry key")
|
||||
testKey.Close()
|
||||
defer func() {
|
||||
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
config5 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
{Domain: "domain3.com", MatchOnly: true},
|
||||
{Domain: "domain4.com", MatchOnly: true},
|
||||
{Domain: "domain5.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config5, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all 5 entries exist
|
||||
for i := 0; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||
}
|
||||
|
||||
config2 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first 2 entries exist
|
||||
for i := 0; i < 2; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||
}
|
||||
|
||||
// Verify entries 2-4 are cleaned up
|
||||
for i := 2; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||
}
|
||||
}
|
||||
|
||||
func registryKeyExists(path string) (bool, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
k.Close()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func cleanupRegistryKeys(*testing.T) {
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||
_ = cfg.removeDNSMatchPolicies()
|
||||
}
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||
}
|
||||
|
||||
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
|
||||
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type ShutdownState struct {
|
||||
CreatedKeys []string
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
|
||||
return fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
for _, key := range s.CreatedKeys {
|
||||
manager.createdKeys[key] = struct{}{}
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||
}
|
||||
|
||||
78
client/internal/dnsfwd/cache.go
Normal file
78
client/internal/dnsfwd/cache.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
ip4Addrs []netip.Addr
|
||||
ip6Addrs []netip.Addr
|
||||
}
|
||||
|
||||
func newCache() *cache {
|
||||
return &cache{
|
||||
records: make(map[string]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.records[normalizeDomain(domain)]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
return slices.Clone(entry.ip4Addrs), true
|
||||
case dns.TypeAAAA:
|
||||
return slices.Clone(entry.ip6Addrs), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
norm := normalizeDomain(domain)
|
||||
entry, exists := c.records[norm]
|
||||
if !exists {
|
||||
entry = &cacheEntry{}
|
||||
c.records[norm] = entry
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
entry.ip4Addrs = slices.Clone(addrs)
|
||||
case dns.TypeAAAA:
|
||||
entry.ip6Addrs = slices.Clone(addrs)
|
||||
}
|
||||
}
|
||||
|
||||
// unset removes cached entries for the given domain and request type.
|
||||
func (c *cache) unset(domain string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.records, normalizeDomain(domain))
|
||||
}
|
||||
|
||||
// normalizeDomain converts an input domain into a canonical form used as cache key:
|
||||
// lowercase and fully-qualified (with trailing dot).
|
||||
func normalizeDomain(domain string) string {
|
||||
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
|
||||
return dns.Fqdn(strings.ToLower(domain))
|
||||
}
|
||||
86
client/internal/dnsfwd/cache_test.go
Normal file
86
client/internal/dnsfwd/cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustAddr(t *testing.T, s string) netip.Addr {
|
||||
t.Helper()
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
t.Fatalf("parse addr %s: %v", s, err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestCacheNormalization(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
// Mixed case, without trailing dot
|
||||
domainInput := "ExAmPlE.CoM"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
|
||||
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
|
||||
|
||||
// Lookup with lower, with trailing dot
|
||||
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Lookup with different casing again
|
||||
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSeparateTypes(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
domain := "test.local"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
|
||||
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
|
||||
|
||||
c.set(domain, 1 /* A */, ipv4)
|
||||
c.set(domain, 28 /* AAAA */, ipv6)
|
||||
|
||||
got4, ok4 := c.get(domain, 1)
|
||||
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
|
||||
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
|
||||
}
|
||||
|
||||
got6, ok6 := c.get(domain, 28)
|
||||
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
|
||||
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCloneOnGetAndSet(t *testing.T) {
|
||||
c := newCache()
|
||||
domain := "clone.test"
|
||||
|
||||
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
|
||||
c.set(domain, 1, src)
|
||||
|
||||
// Mutate source slice; cache should be unaffected
|
||||
src[0] = mustAddr(t, "9.9.9.9")
|
||||
|
||||
got, ok := c.get(domain, 1)
|
||||
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Mutate returned slice; internal cache should remain unchanged
|
||||
got[0] = mustAddr(t, "4.4.4.4")
|
||||
got2, ok2 := c.get(domain, 1)
|
||||
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiss(t *testing.T) {
|
||||
c := newCache()
|
||||
if got, ok := c.get("missing.example", 1); ok || got != nil {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ type DNSForwarder struct {
|
||||
fwdEntries []*ForwarderEntry
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
firewall: firewall,
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
// remove cache entries for domains that no longer appear
|
||||
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
||||
|
||||
f.fwdEntries = entries
|
||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||
}
|
||||
|
||||
// removeStaleCacheEntries unsets cache items for domains that were present
|
||||
// in the old list but not present in the new list.
|
||||
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
||||
if f.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newSet := make(map[string]struct{}, len(newEntries))
|
||||
for _, e := range newEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
newSet[e.Domain.PunycodeString()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, e := range oldEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
pattern := e.Domain.PunycodeString()
|
||||
if _, ok := newSet[pattern]; !ok {
|
||||
f.cache.unset(pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
var result *multierror.Error
|
||||
|
||||
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||
func (f *DNSForwarder) handleDNSError(
|
||||
ctx context.Context,
|
||||
w dns.ResponseWriter,
|
||||
question dns.Question,
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
err error,
|
||||
) {
|
||||
// Default to SERVFAIL; override below when appropriate.
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
|
||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||
if !errors.As(err, &dnsErr) {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Upstream failed but we might have a cached answer—serve it if present.
|
||||
if ips, ok := f.cache.get(domain, qType); ok {
|
||||
if len(ips) > 0 {
|
||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No cache. Log with or without the server field for more context.
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||
}
|
||||
|
||||
// Ensures that when the first query succeeds and populates the cache,
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First call resolves successfully and populates cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
// Second call fails upstream; forwarder should serve from cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
// First query: populate cache
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("9.8.7.6")
|
||||
|
||||
// Initial resolution with mixed case to populate cache
|
||||
mixedQuery := "ExAmPlE.CoM"
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Subsequent query without dot and upper case should hit cache even if upstream fails
|
||||
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
// Test complex overlapping pattern scenarios
|
||||
mockFirewall := &MockFirewall{}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -12,18 +14,14 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
listenPort uint16 = 5353
|
||||
listenPortMu sync.RWMutex
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTTL = 60 //seconds
|
||||
dnsTTL = 60
|
||||
envServerPort = "NB_DNS_FORWARDER_PORT"
|
||||
)
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
@@ -36,24 +34,30 @@ type ForwarderEntry struct {
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
localAddr netip.Addr
|
||||
serverPort uint16
|
||||
|
||||
fwRules []firewall.Rule
|
||||
tcpRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
port uint16
|
||||
}
|
||||
|
||||
func ListenPort() uint16 {
|
||||
listenPortMu.RLock()
|
||||
defer listenPortMu.RUnlock()
|
||||
return listenPort
|
||||
}
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
|
||||
serverPort := nbdns.ForwarderServerPort
|
||||
if envPort := os.Getenv(envServerPort); envPort != "" {
|
||||
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
||||
serverPort = uint16(port)
|
||||
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
|
||||
} else {
|
||||
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
|
||||
}
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
port: port,
|
||||
localAddr: localAddr,
|
||||
serverPort: serverPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,13 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.port > 0 {
|
||||
listenPortMu.Lock()
|
||||
listenPort = m.port
|
||||
listenPortMu.Unlock()
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
}
|
||||
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
}
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -98,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
||||
}
|
||||
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.dropDNSFirewall(); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
@@ -113,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
func (m *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []uint16{ListenPort()},
|
||||
Values: []uint16{m.serverPort},
|
||||
}
|
||||
|
||||
if m.firewall == nil {
|
||||
|
||||
@@ -203,8 +203,7 @@ type Engine struct {
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
|
||||
// dns forwarder port
|
||||
dnsFwdPort uint16
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -247,7 +246,7 @@ func NewEngine(
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
dnsFwdPort: dnsfwd.ListenPort(),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
}
|
||||
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
@@ -1060,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
@@ -1084,7 +1087,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
@@ -1208,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||
if forwarderPort == 0 {
|
||||
forwarderPort = nbdns.ForwarderClientPort
|
||||
}
|
||||
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
ForwarderPort: forwarderPort,
|
||||
}
|
||||
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
@@ -1667,7 +1676,7 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes() bool {
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
signalHealthy := e.signal.IsHealthy()
|
||||
@@ -1699,8 +1708,12 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
results := e.probeICE(stuns, turns)
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
relayHealthy := true
|
||||
@@ -1717,13 +1730,6 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||
return append(
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1843,7 +1849,6 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
forwarderPort uint16,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
return
|
||||
@@ -1860,20 +1865,17 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
switch {
|
||||
case e.dnsForwardMgr == nil:
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
if e.dnsForwardMgr == nil {
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
case e.dnsFwdPort != forwarderPort:
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
e.restartDnsFwd(fwdEntries, forwarderPort)
|
||||
e.dnsFwdPort = forwarderPort
|
||||
|
||||
default:
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
} else {
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
@@ -1883,20 +1885,6 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
// stop and start the forwarder to apply the new port
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
|
||||
@@ -105,6 +105,10 @@ type MockWGIface struct {
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemoveEndpointAddress(key string) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
|
||||
82
client/internal/lazyconn/activity/lazy_conn.go
Normal file
82
client/internal/lazyconn/activity/lazy_conn.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// lazyConn detects activity when WireGuard attempts to send packets.
|
||||
// It does not deliver packets, only signals that activity occurred.
|
||||
type lazyConn struct {
|
||||
activityCh chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// newLazyConn creates a new lazyConn for activity detection.
|
||||
func newLazyConn() *lazyConn {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &lazyConn{
|
||||
activityCh: make(chan struct{}, 1),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Read blocks until the connection is closed.
|
||||
func (c *lazyConn) Read(_ []byte) (n int, err error) {
|
||||
<-c.ctx.Done()
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Write signals activity detection when ICEBind routes packets to this endpoint.
|
||||
func (c *lazyConn) Write(b []byte) (n int, err error) {
|
||||
if c.ctx.Err() != nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
select {
|
||||
case c.activityCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// ActivityChan returns the channel that signals when activity is detected.
|
||||
func (c *lazyConn) ActivityChan() <-chan struct{} {
|
||||
return c.activityCh
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *lazyConn) Close() error {
|
||||
c.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address.
|
||||
func (c *lazyConn) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address.
|
||||
func (c *lazyConn) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines.
|
||||
func (c *lazyConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the deadline for future Read calls.
|
||||
func (c *lazyConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the deadline for future Write calls.
|
||||
func (c *lazyConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
127
client/internal/lazyconn/activity/listener_bind.go
Normal file
127
client/internal/lazyconn/activity/listener_bind.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
type bindProvider interface {
|
||||
GetBind() device.EndpointManager
|
||||
}
|
||||
|
||||
const (
|
||||
// lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers.
|
||||
// The actual routing is done via fakeIP in ICEBind, not by this port.
|
||||
lazyBindPort = 17473
|
||||
)
|
||||
|
||||
// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode.
|
||||
type BindListener struct {
|
||||
wgIface WgInterface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
done sync.WaitGroup
|
||||
|
||||
lazyConn *lazyConn
|
||||
bind device.EndpointManager
|
||||
fakeIP netip.Addr
|
||||
}
|
||||
|
||||
// NewBindListener creates a listener that passes data directly through bind using LazyConn.
|
||||
// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range.
|
||||
func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) {
|
||||
fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive fake IP: %w", err)
|
||||
}
|
||||
|
||||
d := &BindListener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
bind: bind,
|
||||
fakeIP: fakeIP,
|
||||
}
|
||||
|
||||
if err := d.setupLazyConn(); err != nil {
|
||||
return nil, fmt.Errorf("setup lazy connection: %v", err)
|
||||
}
|
||||
|
||||
d.done.Add(1)
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
|
||||
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
|
||||
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
|
||||
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
|
||||
if len(allowedIPs) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
|
||||
}
|
||||
|
||||
ourNetwork := wgIface.Address().Network
|
||||
|
||||
var peerIP netip.Addr
|
||||
for _, allowedIP := range allowedIPs {
|
||||
ip := allowedIP.Addr()
|
||||
if !ip.Is4() {
|
||||
continue
|
||||
}
|
||||
if ourNetwork.Contains(ip) {
|
||||
peerIP = ip
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !peerIP.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||
}
|
||||
|
||||
octets := peerIP.As4()
|
||||
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
|
||||
return fakeIP, nil
|
||||
}
|
||||
|
||||
func (d *BindListener) setupLazyConn() error {
|
||||
d.lazyConn = newLazyConn()
|
||||
d.bind.SetEndpoint(d.fakeIP, d.lazyConn)
|
||||
|
||||
endpoint := &net.UDPAddr{
|
||||
IP: d.fakeIP.AsSlice(),
|
||||
Port: lazyBindPort,
|
||||
}
|
||||
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil)
|
||||
}
|
||||
|
||||
// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed.
|
||||
func (d *BindListener) ReadPackets() {
|
||||
select {
|
||||
case <-d.lazyConn.ActivityChan():
|
||||
d.peerCfg.Log.Infof("activity detected via LazyConn")
|
||||
case <-d.lazyConn.ctx.Done():
|
||||
d.peerCfg.Log.Infof("exit from activity listener")
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.lazyConn.Close()
|
||||
d.bind.RemoveEndpoint(d.fakeIP)
|
||||
d.done.Done()
|
||||
}
|
||||
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *BindListener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
|
||||
|
||||
if err := d.lazyConn.Close(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err)
|
||||
}
|
||||
|
||||
d.done.Wait()
|
||||
}
|
||||
291
client/internal/lazyconn/activity/listener_bind_test.go
Normal file
291
client/internal/lazyconn/activity/listener_bind_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
func isBindListenerPlatform() bool {
|
||||
return runtime.GOOS == "windows" || runtime.GOOS == "js"
|
||||
}
|
||||
|
||||
// mockEndpointManager implements device.EndpointManager for testing
|
||||
type mockEndpointManager struct {
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
}
|
||||
|
||||
func newMockEndpointManager() *mockEndpointManager {
|
||||
return &mockEndpointManager{
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
m.endpoints[fakeIP] = conn
|
||||
}
|
||||
|
||||
func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
delete(m.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn {
|
||||
return m.endpoints[fakeIP]
|
||||
}
|
||||
|
||||
// MockWGIfaceBind mocks WgInterface with bind support
|
||||
type MockWGIfaceBind struct {
|
||||
endpointMgr *mockEndpointManager
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) IsUserspaceBind() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) Address() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
|
||||
return m.endpointMgr
|
||||
}
|
||||
|
||||
func TestBindListener_Creation(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedFakeIP := netip.MustParseAddr("127.2.0.2")
|
||||
conn := mockEndpointMgr.GetEndpoint(expectedFakeIP)
|
||||
require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager")
|
||||
|
||||
_, ok := conn.(*lazyConn)
|
||||
assert.True(t, ok, "Registered endpoint should be a lazyConn")
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindListener_ActivityDetection(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
activityDetected := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(activityDetected)
|
||||
}()
|
||||
|
||||
fakeIP := listener.fakeIP
|
||||
conn := mockEndpointMgr.GetEndpoint(fakeIP)
|
||||
require.NotNil(t, conn, "Endpoint should be registered")
|
||||
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-activityDetected:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity detection")
|
||||
}
|
||||
|
||||
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection")
|
||||
}
|
||||
|
||||
func TestBindListener_Close(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
fakeIP := listener.fakeIP
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
|
||||
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close")
|
||||
}
|
||||
|
||||
func TestManager_BindMode(t *testing.T) {
|
||||
if !isBindListenerPlatform() {
|
||||
t.Skip("BindListener only used on Windows/JS platforms")
|
||||
}
|
||||
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
mgr := NewManager(mockIface)
|
||||
defer mgr.Close()
|
||||
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
err := mgr.MonitorPeerActivity(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
listener, exists := mgr.GetPeerListener(cfg.PeerConnID)
|
||||
require.True(t, exists, "Peer listener should be found")
|
||||
|
||||
bindListener, ok := listener.(*BindListener)
|
||||
require.True(t, ok, "Listener should be BindListener, got %T", listener)
|
||||
|
||||
fakeIP := bindListener.fakeIP
|
||||
conn := mockEndpointMgr.GetEndpoint(fakeIP)
|
||||
require.NotNil(t, conn, "Endpoint should be registered")
|
||||
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notification")
|
||||
}
|
||||
|
||||
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity")
|
||||
}
|
||||
|
||||
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||
if !isBindListenerPlatform() {
|
||||
t.Skip("BindListener only used on Windows/JS platforms")
|
||||
}
|
||||
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer1 := &MocPeer{PeerID: "testPeer1"}
|
||||
peer2 := &MocPeer{PeerID: "testPeer2"}
|
||||
mgr := NewManager(mockIface)
|
||||
defer mgr.Close()
|
||||
|
||||
cfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
cfg2 := lazyconn.PeerConfig{
|
||||
PublicKey: peer2.PeerID,
|
||||
PeerConnID: peer2.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")},
|
||||
Log: log.WithField("peer", "testPeer2"),
|
||||
}
|
||||
|
||||
err := mgr.MonitorPeerActivity(cfg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.MonitorPeerActivity(cfg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID)
|
||||
require.True(t, exists, "Peer1 listener should be found")
|
||||
bindListener1 := listener1.(*BindListener)
|
||||
|
||||
listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID)
|
||||
require.True(t, exists, "Peer2 listener should be found")
|
||||
bindListener2 := listener2.(*BindListener)
|
||||
|
||||
conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP)
|
||||
require.NotNil(t, conn1, "Peer1 endpoint should be registered")
|
||||
_, err = conn1.Write([]byte{0x01})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP)
|
||||
require.NotNil(t, conn2, "Peer2 endpoint should be registered")
|
||||
_, err = conn2.Write([]byte{0x02})
|
||||
require.NoError(t, err)
|
||||
|
||||
receivedPeers := make(map[peerid.ConnID]bool)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
receivedPeers[peerConnID] = true
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notifications")
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received")
|
||||
assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received")
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestNewListener(t *testing.T) {
|
||||
peer := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
l, err := NewListener(MocWGIface{}, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
chanClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(chanClosed)
|
||||
l.ReadPackets()
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
l.Close()
|
||||
|
||||
select {
|
||||
case <-chanClosed:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
@@ -11,26 +11,27 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||
type Listener struct {
|
||||
// UDPListener uses UDP sockets for activity detection in kernel mode.
|
||||
type UDPListener struct {
|
||||
wgIface WgInterface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
conn *net.UDPConn
|
||||
endpoint *net.UDPAddr
|
||||
done sync.Mutex
|
||||
|
||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||
isClosed atomic.Bool
|
||||
}
|
||||
|
||||
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
d := &Listener{
|
||||
// NewUDPListener creates a listener that detects activity via UDP socket reads.
|
||||
func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) {
|
||||
d := &UDPListener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
}
|
||||
|
||||
conn, err := d.newConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
|
||||
return nil, fmt.Errorf("create UDP connection: %v", err)
|
||||
}
|
||||
d.conn = conn
|
||||
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
|
||||
@@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error
|
||||
if err := d.createEndpoint(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.done.Lock()
|
||||
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
|
||||
cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String())
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Listener) ReadPackets() {
|
||||
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
|
||||
func (d *UDPListener) ReadPackets() {
|
||||
for {
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||
if err != nil {
|
||||
@@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() {
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
if err := d.removeEndpoint(); err != nil {
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
|
||||
// Ignore close error as it may return "use of closed network connection" if already closed.
|
||||
_ = d.conn.Close()
|
||||
d.done.Unlock()
|
||||
}
|
||||
|
||||
func (d *Listener) Close() {
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *UDPListener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||
d.isClosed.Store(true)
|
||||
|
||||
@@ -82,16 +87,12 @@ func (d *Listener) Close() {
|
||||
d.done.Lock()
|
||||
}
|
||||
|
||||
func (d *Listener) removeEndpoint() error {
|
||||
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (d *Listener) createEndpoint() error {
|
||||
func (d *UDPListener) createEndpoint() error {
|
||||
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
|
||||
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
|
||||
}
|
||||
|
||||
func (d *Listener) newConn() (*net.UDPConn, error) {
|
||||
func (d *UDPListener) newConn() (*net.UDPConn, error) {
|
||||
addr := &net.UDPAddr{
|
||||
Port: 0,
|
||||
IP: listenIP,
|
||||
110
client/internal/lazyconn/activity/listener_udp_test.go
Normal file
110
client/internal/lazyconn/activity/listener_udp_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestUDPListener_Creation(t *testing.T) {
|
||||
mockIface := &MocWGIface{}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewUDPListener(mockIface, cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, listener.conn)
|
||||
require.NotNil(t, listener.endpoint)
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPListener_ActivityDetection(t *testing.T) {
|
||||
mockIface := &MocWGIface{}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewUDPListener(mockIface, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
activityDetected := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(activityDetected)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("udp", listener.conn.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-activityDetected:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity detection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPListener_Close(t *testing.T) {
|
||||
mockIface := &MocWGIface{}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewUDPListener(mockIface, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
|
||||
assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed")
|
||||
}
|
||||
@@ -1,21 +1,32 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
// listener defines the contract for activity detection listeners.
|
||||
type listener interface {
|
||||
ReadPackets()
|
||||
Close()
|
||||
}
|
||||
|
||||
type WgInterface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -23,7 +34,7 @@ type Manager struct {
|
||||
|
||||
wgIface WgInterface
|
||||
|
||||
peers map[peerid.ConnID]*Listener
|
||||
peers map[peerid.ConnID]listener
|
||||
done chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
@@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
wgIface: wgIface,
|
||||
peers: make(map[peerid.ConnID]*Listener),
|
||||
peers: make(map[peerid.ConnID]listener),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
return m
|
||||
@@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
listener, err := NewListener(m.wgIface, peerCfg)
|
||||
listener, err := m.createListener(peerCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.peers[peerCfg.PeerConnID] = listener
|
||||
|
||||
m.peers[peerCfg.PeerConnID] = listener
|
||||
go m.waitForTraffic(listener, peerCfg.PeerConnID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
// BindListener is only used on Windows and JS platforms:
|
||||
// - JS: Cannot listen to UDP sockets
|
||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||
// gateway points to, preventing them from reaching the loopback interface.
|
||||
// BindListener bypasses this by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
provider, ok := m.wgIface.(bindProvider)
|
||||
if !ok {
|
||||
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
||||
}
|
||||
|
||||
return NewBindListener(m.wgIface, provider.GetBind(), peerCfg)
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -82,8 +115,8 @@ func (m *Manager) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
|
||||
listener.ReadPackets()
|
||||
func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
|
||||
l.ReadPackets()
|
||||
|
||||
m.mu.Lock()
|
||||
if _, ok := m.peers[peerConnID]; !ok {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
@@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error {
|
||||
|
||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Add this method to the Manager struct
|
||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
|
||||
func (m MocWGIface) IsUserspaceBind() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m MocWGIface) Address() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerListener is a test helper to access listeners
|
||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
listener, exists := m.peers[peerConnID]
|
||||
return listener, exists
|
||||
l, exists := m.peers[peerConnID]
|
||||
return l, exists
|
||||
}
|
||||
|
||||
func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
@@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
t.Fatalf("peer listener not found")
|
||||
}
|
||||
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
// Get the UDP listener's address for triggering
|
||||
udpListener, ok := listener.(*UDPListener)
|
||||
if !ok {
|
||||
t.Fatalf("expected UDPListener")
|
||||
}
|
||||
if err := trigger(udpListener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
@@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
|
||||
listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID)
|
||||
udpListener, _ := listener.(*UDPListener)
|
||||
addr := udpListener.conn.LocalAddr().String()
|
||||
|
||||
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
||||
|
||||
@@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
t.Fatalf("peer listener for peer1 not found")
|
||||
}
|
||||
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
udpListener1, _ := listener.(*UDPListener)
|
||||
if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
@@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
t.Fatalf("peer listener for peer2 not found")
|
||||
}
|
||||
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
udpListener2, _ := listener.(*UDPListener)
|
||||
if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
@@ -14,5 +15,6 @@ type WGIface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
LastActivities() map[string]monotime.Time
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/store"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type rcvChan chan *types.EventFields
|
||||
@@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP &&
|
||||
(event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||
@@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
conn.currentConnPriority = conntype.None
|
||||
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||
@@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
if conn.currentConnPriority == conntype.Relay {
|
||||
conn.Log.Debugf("clean up WireGuard config")
|
||||
conn.currentConnPriority = conntype.None
|
||||
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
|
||||
@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
onNewOffeChan := make(chan struct{})
|
||||
onNewOfferChan := make(chan struct{})
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOffeChan <- struct{}{}
|
||||
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOfferChan <- struct{}{}
|
||||
})
|
||||
|
||||
conn.OnRemoteOffer(OfferAnswer{
|
||||
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-onNewOffeChan:
|
||||
case <-onNewOfferChan:
|
||||
// success
|
||||
case <-ctx.Done():
|
||||
t.Error("expected to receive a new offer notification, but timed out")
|
||||
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
onNewOffeChan := make(chan struct{})
|
||||
onNewOfferChan := make(chan struct{})
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOffeChan <- struct{}{}
|
||||
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOfferChan <- struct{}{}
|
||||
})
|
||||
|
||||
conn.OnRemoteAnswer(OfferAnswer{
|
||||
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-onNewOffeChan:
|
||||
case <-onNewOfferChan:
|
||||
// success
|
||||
case <-ctx.Done():
|
||||
t.Error("expected to receive a new offer notification, but timed out")
|
||||
|
||||
20
client/internal/peer/guard/env.go
Normal file
20
client/internal/peer/guard/env.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
|
||||
)
|
||||
|
||||
func GetICEMonitorPeriod() time.Duration {
|
||||
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
|
||||
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
return defaultCandidatesMonitorPeriod
|
||||
}
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
candidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
defaultCandidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ICEMonitor struct {
|
||||
@@ -25,16 +25,19 @@ type ICEMonitor struct {
|
||||
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig icemaker.Config
|
||||
tickerPeriod time.Duration
|
||||
|
||||
currentCandidatesAddress []string
|
||||
candidatesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
|
||||
log.Debugf("prepare ICE monitor with period: %s", period)
|
||||
cm := &ICEMonitor{
|
||||
ReconnectCh: make(chan struct{}, 1),
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
iceConfig: config,
|
||||
tickerPeriod: period,
|
||||
}
|
||||
return cm
|
||||
}
|
||||
@@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
||||
// Initial check to populate the candidates for later comparison
|
||||
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
|
||||
log.Warnf("Failed to check initial ICE candidates: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cm.tickerPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
||||
@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
@@ -44,13 +44,19 @@ type OfferAnswer struct {
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
onNewOfferListeners []*OfferListener
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
// relayListener is not blocking because the listener is using a goroutine to process the messages
|
||||
// and it will only keep the latest message if multiple offers are received in a short time
|
||||
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
|
||||
// and also to avoid processing old offers if multiple offers are received in a short time
|
||||
// the listener will always process the latest offer
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
l := NewOfferListener(offer)
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.relayListener = NewAsyncOfferListener(offer)
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.iceListener = offer
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
// received confirmation from the remote peer -> ready to proceed
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if err := h.sendAnswer(); err != nil {
|
||||
h.log.Errorf("failed to send remote offer confirmation: %s", err)
|
||||
continue
|
||||
}
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
h.log.Infof("stop listening for remote offers and answers")
|
||||
|
||||
@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
|
||||
return oa.SessionID.String()
|
||||
}
|
||||
|
||||
type OfferListener struct {
|
||||
type AsyncOfferListener struct {
|
||||
fn callbackFunc
|
||||
running bool
|
||||
latest *OfferAnswer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewOfferListener(fn callbackFunc) *OfferListener {
|
||||
return &OfferListener{
|
||||
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
|
||||
return &AsyncOfferListener{
|
||||
fn: fn,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
|
||||
runChan <- struct{}{}
|
||||
}
|
||||
|
||||
hl := NewOfferListener(longRunningFn)
|
||||
hl := NewAsyncOfferListener(longRunningFn)
|
||||
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
|
||||
@@ -18,4 +18,5 @@ type WGIface interface {
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetProxy() wgproxy.Proxy
|
||||
Address() wgaddr.Address
|
||||
RemoveEndpointAddress(key string) error
|
||||
}
|
||||
|
||||
@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
|
||||
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
if w.agentConnecting {
|
||||
w.log.Debugf("agent connection is in progress, skipping the offer")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if w.agent != nil {
|
||||
if w.agent != nil || w.agentConnecting {
|
||||
// backward compatibility with old clients that do not send session ID
|
||||
if remoteOfferAnswer.SessionID == nil {
|
||||
w.log.Debugf("agent already exists, skipping the offer")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
|
||||
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
w.agent = nil
|
||||
}
|
||||
|
||||
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||
}
|
||||
|
||||
w.log.Debugf("recreate ICE agent")
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
|
||||
}
|
||||
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
|
||||
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to recreate ICE Agent: %s", err)
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.agent = agent
|
||||
w.agentDialerCancel = dialerCancel
|
||||
w.agentConnecting = true
|
||||
w.muxAgent.Unlock()
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||
} else {
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
|
||||
go w.connect(dialerCtx, agent, remoteOfferAnswer)
|
||||
}
|
||||
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
w.muxAgent.Lock()
|
||||
w.agentConnecting = false
|
||||
w.lastSuccess = time.Now()
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
// todo: the potential problem is a race between the onConnectionStateChange
|
||||
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
// todo review does it make sense to generate new session ID all the time when w.agent==agent
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
|
||||
if w.agent == agent {
|
||||
// consider to remove from here and move to the OnNewOffer
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
w.agent = nil
|
||||
w.agentConnecting = false
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
}
|
||||
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
@@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
config := &Config{
|
||||
// defaults to false only for new (post 0.26) configurations
|
||||
ServerSSHAllowed: util.False(),
|
||||
WgPort: iface.DefaultWgPort,
|
||||
}
|
||||
|
||||
if _, err := config.apply(input); err != nil {
|
||||
|
||||
@@ -5,11 +5,14 @@ import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProfileDefaults(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err, "should create new config")
|
||||
|
||||
assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default")
|
||||
assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default")
|
||||
assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated")
|
||||
assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated")
|
||||
assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default")
|
||||
assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820")
|
||||
assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default")
|
||||
assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default")
|
||||
assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set")
|
||||
assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set")
|
||||
assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults")
|
||||
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
|
||||
assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS")
|
||||
assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireguardPortZeroExplicit(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
// Create a new profile with explicit port 0 (random port)
|
||||
explicitZero := 0
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
WireguardPort: &explicitZero,
|
||||
})
|
||||
require.NoError(t, err, "should create config with explicit port 0")
|
||||
|
||||
assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user")
|
||||
|
||||
// Verify it persists
|
||||
readConfig, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file")
|
||||
}
|
||||
|
||||
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no port specified uses default",
|
||||
wireguardPort: nil,
|
||||
expectedPort: iface.DefaultWgPort,
|
||||
description: "When user doesn't specify port, default to 51820",
|
||||
},
|
||||
{
|
||||
name: "explicit zero for random port",
|
||||
wireguardPort: func() *int { v := 0; return &v }(),
|
||||
expectedPort: 0,
|
||||
description: "When user explicitly sets 0, use 0 for random port",
|
||||
},
|
||||
{
|
||||
name: "explicit custom port",
|
||||
wireguardPort: func() *int { v := 52000; return &v }(),
|
||||
expectedPort: 52000,
|
||||
description: "When user sets custom port, use that port",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
WireguardPort: tt.wireguardPort,
|
||||
})
|
||||
require.NoError(t, err, tt.description)
|
||||
assert.Equal(t, tt.expectedPort, config.WgPort, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -2,6 +2,8 @@ package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -15,6 +17,15 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCacheTTL = 20 * time.Second
|
||||
probeTimeout = 6 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCheckInProgress = errors.New("probe check is already in progress")
|
||||
)
|
||||
|
||||
// ProbeResult holds the info about the result of a relay probe request
|
||||
type ProbeResult struct {
|
||||
URI string
|
||||
@@ -22,8 +33,164 @@ type ProbeResult struct {
|
||||
Addr string
|
||||
}
|
||||
|
||||
type StunTurnProbe struct {
|
||||
cacheResults []ProbeResult
|
||||
cacheTimestamp time.Time
|
||||
cacheKey string
|
||||
cacheTTL time.Duration
|
||||
probeInProgress bool
|
||||
probeDone chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
|
||||
return &StunTurnProbe{
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
cacheKey := generateCacheKey(stuns, turns)
|
||||
|
||||
p.mu.Lock()
|
||||
if p.probeInProgress {
|
||||
doneChan := p.probeDone
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Context cancelled while waiting for probe results")
|
||||
return createErrorResults(stuns, turns)
|
||||
case <-doneChan:
|
||||
return p.getCachedResults(cacheKey, stuns, turns)
|
||||
}
|
||||
}
|
||||
|
||||
p.probeInProgress = true
|
||||
probeDone := make(chan struct{})
|
||||
p.probeDone = probeDone
|
||||
p.mu.Unlock()
|
||||
|
||||
p.doProbe(ctx, stuns, turns, cacheKey)
|
||||
close(probeDone)
|
||||
|
||||
return p.getCachedResults(cacheKey, stuns, turns)
|
||||
}
|
||||
|
||||
// ProbeAll probes all given servers asynchronously and returns the results
|
||||
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
cacheKey := generateCacheKey(stuns, turns)
|
||||
|
||||
p.mu.Lock()
|
||||
|
||||
if results := p.checkCache(cacheKey); results != nil {
|
||||
p.mu.Unlock()
|
||||
return results
|
||||
}
|
||||
|
||||
if p.probeInProgress {
|
||||
p.mu.Unlock()
|
||||
return createErrorResults(stuns, turns)
|
||||
}
|
||||
|
||||
p.probeInProgress = true
|
||||
probeDone := make(chan struct{})
|
||||
p.probeDone = probeDone
|
||||
log.Infof("started new probe for STUN, TURN servers")
|
||||
go func() {
|
||||
p.doProbe(ctx, stuns, turns, cacheKey)
|
||||
close(probeDone)
|
||||
}()
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
timer := time.NewTimer(1300 * time.Millisecond)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Context cancelled while waiting for probe results")
|
||||
return createErrorResults(stuns, turns)
|
||||
case <-probeDone:
|
||||
// when the probe is return fast, return the results right away
|
||||
return p.getCachedResults(cacheKey, stuns, turns)
|
||||
case <-timer.C:
|
||||
// if the probe takes longer than 1.3s, return error results to avoid blocking
|
||||
return createErrorResults(stuns, turns)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
|
||||
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
|
||||
age := time.Since(p.cacheTimestamp)
|
||||
if age < p.cacheTTL {
|
||||
results := append([]ProbeResult(nil), p.cacheResults...)
|
||||
log.Debugf("returning cached probe results (age: %v)", age)
|
||||
return results
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
|
||||
return append([]ProbeResult(nil), p.cacheResults...)
|
||||
}
|
||||
return createErrorResults(stuns, turns)
|
||||
}
|
||||
|
||||
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
|
||||
defer func() {
|
||||
p.mu.Lock()
|
||||
p.probeInProgress = false
|
||||
p.mu.Unlock()
|
||||
}()
|
||||
results := make([]ProbeResult, len(stuns)+len(turns))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, uri := range stuns {
|
||||
wg.Add(1)
|
||||
go func(idx int, stunURI *stun.URI) {
|
||||
defer wg.Done()
|
||||
|
||||
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
results[idx].URI = stunURI.String()
|
||||
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
|
||||
}(i, uri)
|
||||
}
|
||||
|
||||
stunOffset := len(stuns)
|
||||
for i, uri := range turns {
|
||||
wg.Add(1)
|
||||
go func(idx int, turnURI *stun.URI) {
|
||||
defer wg.Done()
|
||||
|
||||
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
results[idx].URI = turnURI.String()
|
||||
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
|
||||
}(stunOffset+i, uri)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
p.mu.Lock()
|
||||
p.cacheResults = results
|
||||
p.cacheTimestamp = time.Now()
|
||||
p.cacheKey = cacheKey
|
||||
p.mu.Unlock()
|
||||
|
||||
log.Debug("Stored new probe results in cache")
|
||||
}
|
||||
|
||||
// ProbeSTUN tries binding to the given STUN uri and acquiring an address
|
||||
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
defer func() {
|
||||
if probeErr != nil {
|
||||
log.Debugf("stun probe error from %s: %s", uri, probeErr)
|
||||
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
}
|
||||
|
||||
// ProbeTURN tries allocating a session from the given TURN URI
|
||||
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||
defer func() {
|
||||
if probeErr != nil {
|
||||
log.Debugf("turn probe error from %s: %s", uri, probeErr)
|
||||
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
return relayConn.LocalAddr().String(), nil
|
||||
}
|
||||
|
||||
// ProbeAll probes all given servers asynchronously and returns the results
|
||||
func ProbeAll(
|
||||
ctx context.Context,
|
||||
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
|
||||
relays []*stun.URI,
|
||||
) []ProbeResult {
|
||||
results := make([]ProbeResult, len(relays))
|
||||
func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
|
||||
total := len(stuns) + len(turns)
|
||||
results := make([]ProbeResult, total)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, uri := range relays {
|
||||
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
go func(res *ProbeResult, stunURI *stun.URI) {
|
||||
defer wg.Done()
|
||||
res.URI = stunURI.String()
|
||||
res.Addr, res.Err = fn(ctx, stunURI)
|
||||
}(&results[i], uri)
|
||||
allURIs := append(append([]*stun.URI{}, stuns...), turns...)
|
||||
for i, uri := range allURIs {
|
||||
results[i] = ProbeResult{
|
||||
URI: uri.String(),
|
||||
Err: ErrCheckInProgress,
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
|
||||
h := sha256.New()
|
||||
for _, uri := range stuns {
|
||||
h.Write([]byte(uri.String()))
|
||||
}
|
||||
for _, uri := range turns {
|
||||
h.Write([]byte(uri.String()))
|
||||
}
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -25,4 +26,5 @@ type HandlerParams struct {
|
||||
UseNewDNSRoute bool
|
||||
Firewall manager.Manager
|
||||
FakeIPManager *fakeip.Manager
|
||||
ForwarderPort *atomic.Uint32
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
@@ -55,6 +55,7 @@ type DnsInterceptor struct {
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
forwarderPort *atomic.Uint32
|
||||
}
|
||||
|
||||
func New(params common.HandlerParams) *DnsInterceptor {
|
||||
@@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor {
|
||||
firewall: params.Firewall,
|
||||
fakeIPManager: params.FakeIPManager,
|
||||
interceptedDomains: make(domainMap),
|
||||
forwarderPort: params.ForwarderPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
@@ -54,6 +56,7 @@ type Manager interface {
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
SetFirewall(firewall.Manager) error
|
||||
SetDNSForwarderPort(port uint16)
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
@@ -101,12 +104,13 @@ type DefaultManager struct {
|
||||
disableServerRoutes bool
|
||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||
fakeIPManager *fakeip.Manager
|
||||
dnsForwarderPort atomic.Uint32
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(config.Context)
|
||||
notifier := notifier.NewNotifier()
|
||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||
sysOps := systemops.New(config.WGInterface, notifier)
|
||||
|
||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||
@@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
disableServerRoutes: config.DisableServerRoutes,
|
||||
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
|
||||
}
|
||||
dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort))
|
||||
|
||||
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
||||
dm.setupRefCounters(useNoop)
|
||||
@@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDNSForwarderPort sets the DNS forwarder port for route handlers
|
||||
func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
|
||||
m.dnsForwarderPort.Store(uint32(port))
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
m.stop()
|
||||
@@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
UseNewDNSRoute: m.useNewDNSRoute,
|
||||
Firewall: m.firewall,
|
||||
FakeIPManager: m.fakeIPManager,
|
||||
ForwarderPort: &m.dnsForwarderPort,
|
||||
}
|
||||
handler := client.HandlerFromRoute(params)
|
||||
if err := handler.AddRoute(m.ctx); err != nil {
|
||||
|
||||
@@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface
|
||||
func (m *MockManager) SetDNSForwarderPort(port uint16) {
|
||||
}
|
||||
|
||||
// Stop mock implementation of Stop from Manager interface
|
||||
func (m *MockManager) Stop(stateManager *statemanager.Manager) {
|
||||
if m.StopFunc != nil {
|
||||
|
||||
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
|
||||
|
||||
package systemops
|
||||
|
||||
// FlushMarkedRoutes is a no-op on non-BSD platforms.
|
||||
func (r *SysOps) FlushMarkedRoutes() error {
|
||||
return nil
|
||||
}
|
||||
@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
sysops := NewSysOps(nil, nil)
|
||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
sysOps := New(nil, nil)
|
||||
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
|
||||
sysOps.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
|
||||
return sysops.refCounter.Flush()
|
||||
return sysOps.refCounter.Flush()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
|
||||
@@ -83,7 +83,7 @@ type SysOps struct {
|
||||
localSubnetsCacheTime time.Time
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
||||
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
r := New(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err)
|
||||
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
r := New(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err)
|
||||
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
|
||||
assert.NoError(t, wgInterface.Close())
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
r := New(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
|
||||
@@ -7,19 +7,39 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||
)
|
||||
|
||||
var routeProtoFlag int
|
||||
|
||||
func init() {
|
||||
switch os.Getenv(envRouteProtoFlag) {
|
||||
case "2":
|
||||
routeProtoFlag = unix.RTF_PROTO2
|
||||
case "3":
|
||||
routeProtoFlag = unix.RTF_PROTO3
|
||||
default:
|
||||
routeProtoFlag = unix.RTF_PROTO1
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||
func (r *SysOps) FlushMarkedRoutes() error {
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch routing table: %w", err)
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse routing table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
flushedCount := 0
|
||||
|
||||
for _, msg := range msgs {
|
||||
rtMsg, ok := msg.(*route.RouteMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
routeInfo, err := MsgToRoute(rtMsg)
|
||||
if err != nil {
|
||||
log.Debugf("Skipping route flush: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
|
||||
continue
|
||||
}
|
||||
|
||||
nexthop := Nexthop{
|
||||
IP: routeInfo.Gw,
|
||||
Intf: routeInfo.Interface,
|
||||
}
|
||||
|
||||
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
|
||||
continue
|
||||
}
|
||||
|
||||
flushedCount++
|
||||
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
|
||||
}
|
||||
|
||||
if flushedCount > 0 {
|
||||
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||
}
|
||||
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
|
||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||
msg = &route.RouteMessage{
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP,
|
||||
Flags: unix.RTF_UP | routeProtoFlag,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
}
|
||||
|
||||
@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
|
||||
data, err := os.ReadFile(m.filePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
log.Debug("state file does not exist")
|
||||
log.Debugf("state file %s does not exist", m.filePath)
|
||||
return nil, nil // nolint:nilnil
|
||||
}
|
||||
return nil, fmt.Errorf("read state file: %w", err)
|
||||
|
||||
59
client/internal/winregistry/volatile_windows.go
Normal file
59
client/internal/winregistry/volatile_windows.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package winregistry
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
advapi = syscall.NewLazyDLL("advapi32.dll")
|
||||
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
|
||||
)
|
||||
|
||||
const (
|
||||
// Registry key options
|
||||
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
|
||||
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
|
||||
|
||||
// Registry disposition values
|
||||
regCreatedNewKey = 0x1
|
||||
regOpenedExistingKey = 0x2
|
||||
)
|
||||
|
||||
// CreateVolatileKey creates a volatile registry key named path under open key root.
|
||||
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
|
||||
// The access parameter specifies the access rights for the key to be created.
|
||||
//
|
||||
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
|
||||
// This provides automatic cleanup without requiring manual registry maintenance.
|
||||
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
|
||||
pathPtr, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
var (
|
||||
handle syscall.Handle
|
||||
disposition uint32
|
||||
)
|
||||
|
||||
ret, _, _ := regCreateKeyExW.Call(
|
||||
uintptr(root),
|
||||
uintptr(unsafe.Pointer(pathPtr)),
|
||||
0, // reserved
|
||||
0, // class
|
||||
uintptr(regOptionVolatile), // options - volatile key
|
||||
uintptr(access), // desired access
|
||||
0, // security attributes
|
||||
uintptr(unsafe.Pointer(&handle)),
|
||||
uintptr(unsafe.Pointer(&disposition)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return 0, false, syscall.Errno(ret)
|
||||
}
|
||||
|
||||
return registry.Key(handle), disposition == regOpenedExistingKey, nil
|
||||
}
|
||||
@@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.CustomDNSAddress = []byte{}
|
||||
}
|
||||
|
||||
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
interval := msg.DnsRouteInterval.AsDuration()
|
||||
config.DNSRouteInterval = &interval
|
||||
}
|
||||
|
||||
config.RosenpassEnabled = msg.RosenpassEnabled
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
@@ -1050,10 +1057,7 @@ func (s *Server) Status(
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
if msg.GetFullPeerStatus {
|
||||
if msg.ShouldRunProbes {
|
||||
s.runProbes()
|
||||
}
|
||||
|
||||
s.runProbes(msg.ShouldRunProbes)
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
@@ -1063,7 +1067,7 @@ func (s *Server) Status(
|
||||
return &statusResponse, nil
|
||||
}
|
||||
|
||||
func (s *Server) runProbes() {
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
}
|
||||
@@ -1074,7 +1078,7 @@ func (s *Server) runProbes() {
|
||||
}
|
||||
|
||||
if time.Since(s.lastProbe) > probeThreshold {
|
||||
if engine.RunHealthProbes() {
|
||||
if engine.RunHealthProbes(waitForProbeResult) {
|
||||
s.lastProbe = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
298
client/server/setconfig_test.go
Normal file
298
client/server/setconfig_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
|
||||
// This test uses reflection to detect when new fields are added but not handled in SetConfig.
|
||||
func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origDefaultConfigPath := profilemanager.DefaultConfigPath
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.ConfigDirOverride = tempDir
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPath = origDefaultConfigPath
|
||||
profilemanager.ConfigDirOverride = ""
|
||||
})
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
profName := "test-profile"
|
||||
|
||||
ic := profilemanager.ConfigInput{
|
||||
ConfigPath: filepath.Join(tempDir, profName+".json"),
|
||||
ManagementURL: "https://api.netbird.io:443",
|
||||
}
|
||||
_, err = profilemanager.UpdateOrCreateConfig(ic)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
s := New(ctx, "console", "", false, false)
|
||||
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
disableAutoConnect := true
|
||||
networkMonitor := true
|
||||
disableClientRoutes := true
|
||||
disableServerRoutes := true
|
||||
disableDNS := true
|
||||
disableFirewall := true
|
||||
blockLANAccess := true
|
||||
disableNotifications := true
|
||||
lazyConnectionEnabled := true
|
||||
blockInbound := true
|
||||
mtu := int64(1280)
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: currUser.Username,
|
||||
ManagementUrl: "https://new-api.netbird.io:443",
|
||||
AdminURL: "https://new-admin.netbird.io",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
DisableAutoConnect: &disableAutoConnect,
|
||||
NetworkMonitor: &networkMonitor,
|
||||
DisableClientRoutes: &disableClientRoutes,
|
||||
DisableServerRoutes: &disableServerRoutes,
|
||||
DisableDns: &disableDNS,
|
||||
DisableFirewall: &disableFirewall,
|
||||
BlockLanAccess: &blockLANAccess,
|
||||
DisableNotifications: &disableNotifications,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
|
||||
DnsLabels: []string{"label1", "label2"},
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
}
|
||||
|
||||
_, err = s.SetConfig(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
}
|
||||
cfgPath, err := profState.FilePath()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := profilemanager.GetConfig(cfgPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
|
||||
require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
|
||||
require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
|
||||
require.NotNil(t, cfg.NetworkMonitor)
|
||||
require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
|
||||
require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
|
||||
require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
|
||||
require.Equal(t, disableDNS, cfg.DisableDNS)
|
||||
require.Equal(t, disableFirewall, cfg.DisableFirewall)
|
||||
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
|
||||
require.NotNil(t, cfg.DisableNotifications)
|
||||
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
|
||||
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
|
||||
require.Equal(t, blockInbound, cfg.BlockInbound)
|
||||
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
|
||||
require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
|
||||
// IFaceBlackList contains defaults + extras
|
||||
require.Contains(t, cfg.IFaceBlackList, "eth1")
|
||||
require.Contains(t, cfg.IFaceBlackList, "eth2")
|
||||
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
|
||||
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
|
||||
require.Equal(t, uint16(mtu), cfg.MTU)
|
||||
|
||||
verifyAllFieldsCovered(t, req)
|
||||
}
|
||||
|
||||
// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
|
||||
// If a new field is added to SetConfigRequest, this function will fail the test,
|
||||
// forcing the developer to update both the SetConfig handler and this test.
|
||||
func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
t.Helper()
|
||||
|
||||
metadataFields := map[string]bool{
|
||||
"state": true, // protobuf internal
|
||||
"sizeCache": true, // protobuf internal
|
||||
"unknownFields": true, // protobuf internal
|
||||
"Username": true, // metadata
|
||||
"ProfileName": true, // metadata
|
||||
"CleanNATExternalIPs": true, // control flag for clearing
|
||||
"CleanDNSLabels": true, // control flag for clearing
|
||||
}
|
||||
|
||||
expectedFields := map[string]bool{
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(req).Elem()
|
||||
typ := val.Type()
|
||||
|
||||
var unexpectedFields []string
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldName := field.Name
|
||||
|
||||
if metadataFields[fieldName] {
|
||||
continue
|
||||
}
|
||||
|
||||
if !expectedFields[fieldName] {
|
||||
unexpectedFields = append(unexpectedFields, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unexpectedFields) > 0 {
|
||||
t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
|
||||
// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
|
||||
func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
// Map of CLI flag names to their corresponding SetConfigRequest field names.
|
||||
// This map must be updated when adding new config-related CLI flags.
|
||||
flagToField := map[string]string{
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
}
|
||||
|
||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||
fieldsWithoutCLIFlags := map[string]bool{
|
||||
"DisableNotifications": true, // Only settable via UI
|
||||
}
|
||||
|
||||
// Get all SetConfigRequest fields to verify our map is complete.
|
||||
req := &proto.SetConfigRequest{}
|
||||
val := reflect.ValueOf(req).Elem()
|
||||
typ := val.Type()
|
||||
|
||||
var unmappedFields []string
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldName := field.Name
|
||||
|
||||
// Skip protobuf internal fields and metadata fields.
|
||||
if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" {
|
||||
continue
|
||||
}
|
||||
if fieldName == "Username" || fieldName == "ProfileName" {
|
||||
continue
|
||||
}
|
||||
if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag.
|
||||
mappedToCLI := false
|
||||
for _, mappedField := range flagToField {
|
||||
if mappedField == fieldName {
|
||||
mappedToCLI = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName]
|
||||
|
||||
if !mappedToCLI && !hasNoCLIFlag {
|
||||
unmappedFields = append(unmappedFields, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unmappedFields) > 0 {
|
||||
t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
|
||||
"Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
|
||||
"add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
|
||||
}
|
||||
|
||||
t.Log("All SetConfigRequest fields are properly documented")
|
||||
}
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
||||
}
|
||||
|
||||
// clean up any remaining routes independently of the state file
|
||||
if !nbnet.AdvancedRouting() {
|
||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -205,15 +206,18 @@ func mapPeers(
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := "P2P"
|
||||
connType := "-"
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
if isPeerConnected {
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
}
|
||||
|
||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
|
||||
@@ -337,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
|
||||
if !relay.Available {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
if relay.Error == probeRelay.ErrCheckInProgress.Error() {
|
||||
available = "Checking..."
|
||||
} else {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
}
|
||||
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"fyne.io/systray"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@@ -297,6 +296,8 @@ type serviceClient struct {
|
||||
mExitNodeDeselectAll *systray.MenuItem
|
||||
logFile string
|
||||
wLoginURL fyne.Window
|
||||
|
||||
connectCancel context.CancelFunc
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -593,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Errorf("get active profile: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
@@ -611,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||
return nil, fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
|
||||
loginResp, err := conn.Login(ctx, &proto.LoginRequest{
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("login to management URL with: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("login to management: %w", err)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin && openURL {
|
||||
err = s.handleSSOLogin(loginResp, conn)
|
||||
if err != nil {
|
||||
log.Errorf("handle SSO login failed: %v", err)
|
||||
return nil, err
|
||||
if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
|
||||
return nil, fmt.Errorf("SSO login: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||
err := open.Run(loginResp.VerificationURIComplete)
|
||||
if err != nil {
|
||||
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
||||
return err
|
||||
func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||
if err := openURL(loginResp.VerificationURIComplete); err != nil {
|
||||
return fmt.Errorf("open browser: %w", err)
|
||||
}
|
||||
|
||||
resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
if err != nil {
|
||||
log.Errorf("waiting sso login failed with: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("wait for SSO login: %w", err)
|
||||
}
|
||||
|
||||
if resp.Email != "" {
|
||||
err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
|
||||
Email: resp.Email,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("failed to set profile state: %v", err)
|
||||
}); err != nil {
|
||||
log.Debugf("failed to set profile state: %v", err)
|
||||
} else {
|
||||
s.mProfile.refresh()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) menuUpClick() error {
|
||||
func (s *serviceClient) menuUpClick(ctx context.Context) error {
|
||||
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
systray.SetTemplateIcon(iconErrorMacOS, s.icError)
|
||||
log.Errorf("get client: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.login(true)
|
||||
_, err = s.login(ctx, true)
|
||||
if err != nil {
|
||||
log.Errorf("login failed with: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get status: %w", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
log.Warnf("already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
return err
|
||||
if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("start connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -698,24 +684,20 @@ func (s *serviceClient) menuDownClick() error {
|
||||
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get status: %w", err)
|
||||
}
|
||||
|
||||
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
|
||||
log.Warnf("already down")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("down service: %v", err)
|
||||
return err
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("stop connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1354,7 +1336,13 @@ func (s *serviceClient) updateConfig() error {
|
||||
}
|
||||
|
||||
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
|
||||
func (s *serviceClient) showLoginURL() {
|
||||
// It also starts a background goroutine that periodically checks if the client is already connected
|
||||
// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
|
||||
// also cancelled when the window is closed.
|
||||
func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
|
||||
// create a cancellable context for the background check goroutine
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
|
||||
|
||||
@@ -1363,6 +1351,8 @@ func (s *serviceClient) showLoginURL() {
|
||||
s.wLoginURL.Resize(fyne.NewSize(400, 200))
|
||||
s.wLoginURL.SetIcon(resIcon)
|
||||
}
|
||||
// ensure goroutine is cancelled when the window is closed
|
||||
s.wLoginURL.SetOnClosed(func() { cancel() })
|
||||
// add a description label
|
||||
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
|
||||
|
||||
@@ -1374,7 +1364,7 @@ func (s *serviceClient) showLoginURL() {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.login(false)
|
||||
resp, err := s.login(ctx, false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to fetch login URL: %v", err)
|
||||
return
|
||||
@@ -1394,7 +1384,7 @@ func (s *serviceClient) showLoginURL() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
|
||||
_, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
|
||||
if err != nil {
|
||||
log.Errorf("Waiting sso login failed with: %v", err)
|
||||
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
|
||||
@@ -1402,7 +1392,7 @@ func (s *serviceClient) showLoginURL() {
|
||||
}
|
||||
|
||||
label.SetText("Re-authentication successful.\nReconnecting")
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return
|
||||
@@ -1415,7 +1405,7 @@ func (s *serviceClient) showLoginURL() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
_, err = conn.Up(ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
|
||||
log.Errorf("Reconnecting failed with: %v", err)
|
||||
@@ -1443,10 +1433,46 @@ func (s *serviceClient) showLoginURL() {
|
||||
)
|
||||
s.wLoginURL.SetContent(container.NewCenter(content))
|
||||
|
||||
// start a goroutine to check connection status and close the window if connected
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
if s.wLoginURL != nil {
|
||||
s.wLoginURL.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.wLoginURL.Show()
|
||||
|
||||
// return cancel func so callers can stop the background goroutine if desired
|
||||
return cancel
|
||||
}
|
||||
|
||||
func openURL(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
|
||||
var err error
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
|
||||
return "", err
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get post-up status: %v", err)
|
||||
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
|
||||
func (h *eventHandler) handleConnectClick() {
|
||||
h.client.mUp.Disable()
|
||||
|
||||
if h.client.connectCancel != nil {
|
||||
h.client.connectCancel()
|
||||
}
|
||||
|
||||
connectCtx, connectCancel := context.WithCancel(h.client.ctx)
|
||||
h.client.connectCancel = connectCancel
|
||||
|
||||
go func() {
|
||||
defer h.client.mUp.Enable()
|
||||
if err := h.client.menuUpClick(); err != nil {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||
defer connectCancel()
|
||||
|
||||
if err := h.client.menuUpClick(connectCtx); err != nil {
|
||||
st, ok := status.FromError(err)
|
||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||
log.Debugf("connect operation cancelled by user")
|
||||
} else {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
|
||||
log.Errorf("connect failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.client.updateStatus(); err != nil {
|
||||
log.Debugf("failed to update status after connect: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleDisconnectClick() {
|
||||
h.client.mDown.Disable()
|
||||
|
||||
if h.client.connectCancel != nil {
|
||||
log.Debugf("cancelling ongoing connect operation")
|
||||
h.client.connectCancel()
|
||||
h.client.connectCancel = nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer h.client.mDown.Enable()
|
||||
if err := h.client.menuDownClick(); err != nil {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon"))
|
||||
st, ok := status.FromError(err)
|
||||
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
|
||||
log.Errorf("disconnect failed: %v", err)
|
||||
} else {
|
||||
log.Debugf("disconnect cancelled or already disconnecting")
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.client.updateStatus(); err != nil {
|
||||
log.Debugf("failed to update status after disconnect: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
h.client.getSrvConfig()
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -387,6 +387,7 @@ type subItem struct {
|
||||
type profileMenu struct {
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
serviceClient *serviceClient
|
||||
profileManager *profilemanager.ProfileManager
|
||||
eventHandler *eventHandler
|
||||
profileMenuItem *systray.MenuItem
|
||||
@@ -396,7 +397,7 @@ type profileMenu struct {
|
||||
logoutSubItem *subItem
|
||||
profilesState []Profile
|
||||
downClickCallback func() error
|
||||
upClickCallback func() error
|
||||
upClickCallback func(context.Context) error
|
||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||
loadSettingsCallback func()
|
||||
app fyne.App
|
||||
@@ -404,12 +405,13 @@ type profileMenu struct {
|
||||
|
||||
type newProfileMenuArgs struct {
|
||||
ctx context.Context
|
||||
serviceClient *serviceClient
|
||||
profileManager *profilemanager.ProfileManager
|
||||
eventHandler *eventHandler
|
||||
profileMenuItem *systray.MenuItem
|
||||
emailMenuItem *systray.MenuItem
|
||||
downClickCallback func() error
|
||||
upClickCallback func() error
|
||||
upClickCallback func(context.Context) error
|
||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||
loadSettingsCallback func()
|
||||
app fyne.App
|
||||
@@ -418,6 +420,7 @@ type newProfileMenuArgs struct {
|
||||
func newProfileMenu(args newProfileMenuArgs) *profileMenu {
|
||||
p := profileMenu{
|
||||
ctx: args.ctx,
|
||||
serviceClient: args.serviceClient,
|
||||
profileManager: args.profileManager,
|
||||
eventHandler: args.eventHandler,
|
||||
profileMenuItem: args.profileMenuItem,
|
||||
@@ -569,10 +572,19 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.upClickCallback(); err != nil {
|
||||
if p.serviceClient.connectCancel != nil {
|
||||
p.serviceClient.connectCancel()
|
||||
}
|
||||
|
||||
connectCtx, connectCancel := context.WithCancel(p.ctx)
|
||||
p.serviceClient.connectCancel = connectCancel
|
||||
|
||||
if err := p.upClickCallback(connectCtx); err != nil {
|
||||
log.Errorf("failed to handle up click after switching profile: %v", err)
|
||||
}
|
||||
|
||||
connectCancel()
|
||||
|
||||
p.refresh()
|
||||
p.loadSettingsCallback()
|
||||
}
|
||||
|
||||
@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||
return &tls.Config{
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
|
||||
if requiresCredSSP {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS12
|
||||
} else {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS13
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -19,18 +21,34 @@ const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
|
||||
rdpDialTimeout = 15 * time.Second
|
||||
|
||||
GeneralErrorCode = 1
|
||||
WSAETimedOut = 10060
|
||||
WSAEConnRefused = 10061
|
||||
WSAEConnAborted = 10053
|
||||
WSAEConnReset = 10054
|
||||
WSAEGenericError = 10050
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathErr struct {
|
||||
ErrorCode int16 `asn1:"tag:0,explicit"`
|
||||
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
|
||||
WSALastError int16 `asn1:"tag:2,explicit,optional"`
|
||||
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func errorToWSACode(err error) int16 {
|
||||
if err == nil {
|
||||
return WSAEGenericError
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return WSAEConnAborted
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return WSAEConnReset
|
||||
}
|
||||
return WSAEGenericError
|
||||
}
|
||||
|
||||
func newWSAError(err error) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
WSALastError: errorToWSACode(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPError(statusCode int16) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
HTTPStatusCode: statusCode,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
@@ -11,11 +12,17 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
|
||||
protocolSSL = 0x00000001
|
||||
protocolHybridEx = 0x00000008
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||
p.sendRDCleanPathError(conn, newHTTPError(400))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, "Connection failed")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
|
||||
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
|
||||
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
|
||||
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
|
||||
const minResponseLength = 19
|
||||
|
||||
if len(x224Response) < minResponseLength {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
// Per X.224 specification:
|
||||
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
|
||||
// x224Response[5] == 0xD0: X.224 Data TPDU code
|
||||
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
if x224Response[11] == 0x02 {
|
||||
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
|
||||
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
|
||||
|
||||
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
|
||||
return hasNLA, flags, true
|
||||
}
|
||||
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
|
||||
if detected {
|
||||
if requiresCredSSP {
|
||||
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
|
||||
} else {
|
||||
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
X224ConnectionPDU: response[:n],
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
} else {
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
}
|
||||
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TCP connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||
pdu := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: []byte(errorMsg),
|
||||
}
|
||||
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user