Compare commits

..

3 Commits

Author SHA1 Message Date
Zoltán Papp
2734a3356e [client] Fix nil SessionID panic and force ICE teardown on relay-only transition
Fix nil pointer dereference in signalOfferAnswer when SessionID is nil
(relay-only offers). Close stale ICE agent immediately when remote peer
stops sending ICE credentials to avoid traffic black-hole during the
ICE disconnect timeout.
2026-04-07 17:12:42 +02:00
Zoltán Papp
c6d660df4e [client] Dynamically suppress ICE based on remote peer's offer credentials
Track whether the remote peer includes ICE credentials in its
offers/answers. When remote stops sending ICE credentials, skip
ICE listener dispatch, suppress ICE credentials in responses, and
exclude ICE from the guard connectivity check. When remote resumes
sending ICE credentials, re-enable all ICE behavior.
2026-04-07 16:07:41 +02:00
Zoltán Papp
721251460c [client] Suppress ICE signaling and periodic offers in force-relay mode
When NB_FORCE_RELAY is enabled, skip WorkerICE creation entirely,
suppress ICE credentials in offer/answer messages, disable the
periodic ICE candidate monitor, and fix isConnectedOnAllWay to
only check relay status so the guard stops sending unnecessary offers.
2026-04-07 14:08:26 +02:00
80 changed files with 797 additions and 2817 deletions

View File

@@ -199,11 +199,9 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = !stateWasDown
cmd.Println("netbird down")
}
@@ -219,7 +217,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = false
cmd.Println("netbird up")
}
@@ -267,14 +264,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
if needsRestoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up (restored)")
}
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())

View File

@@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
@@ -202,7 +201,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
return fmt.Errorf("expose service: %w", err)
}
if err := handleExposeReady(cmd, stream, port); err != nil {
@@ -237,7 +236,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
return fmt.Errorf("receive expose event: %w", err)
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)

View File

@@ -286,22 +286,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"

View File

@@ -36,7 +36,6 @@ const (
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
@@ -44,7 +43,6 @@ const (
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
@@ -389,14 +387,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
for _, chainInfo := range []struct {
chain string
table string
@@ -406,7 +396,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
@@ -981,81 +970,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -169,14 +169,6 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error

View File

@@ -346,22 +346,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"

View File

@@ -36,7 +36,6 @@ const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
@@ -1854,130 +1853,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -140,17 +140,6 @@ type Manager struct {
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -605,8 +594,6 @@ func (m *Manager) resetState() {
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -726,9 +713,6 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -911,21 +895,38 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
m.mutex.RLock()
defer m.mutex.RUnlock()
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
@@ -1277,6 +1278,12 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
@@ -1335,30 +1342,65 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched
}
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -187,52 +186,81 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
var called bool
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
called = true
return true
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
assert.Nil(t, manager.udpHookOut.Load())
}
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
var called bool
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
called = true
return true
})
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
assert.Nil(t, manager.tcpHookOut.Load())
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
})
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
@@ -502,12 +530,39 @@ func TestRemovePacketHook(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
}
func TestProcessOutgoingHooks(t *testing.T) {
@@ -537,7 +592,8 @@ func TestProcessOutgoingHooks(t *testing.T) {
}
hookCalled := false
manager.SetUDPPacketHook(
hookID := manager.AddUDPPacketHook(
false,
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
@@ -545,6 +601,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{

View File

@@ -144,8 +144,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}

View File

@@ -421,7 +421,6 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -467,22 +466,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {

View File

@@ -18,7 +18,9 @@ type PeerRule struct {
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
drop bool
udpHook func([]byte) bool
}
// ID returns the rule id

View File

@@ -399,17 +399,21 @@ func TestTracePacket(t *testing.T) {
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
return true // drop (intercepted by hook)
})
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,

View File

@@ -15,17 +15,14 @@ type PacketFilter interface {
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one UDP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one TCP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
}
// FilteredDevice to override Read or Write of packets

View File

@@ -34,28 +34,18 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// SetUDPPacketHook mocks base method.
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
}
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
}
// SetTCPPacketHook mocks base method.
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
}
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
@@ -85,3 +75,17 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}

View File

@@ -0,0 +1,87 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -111,7 +111,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -121,7 +120,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")

View File

@@ -73,9 +73,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
return nil
}
w.response = m
if m.MsgHdr.Truncated {
w.SetMeta("truncated", "true")
}
return w.ResponseWriter.WriteMsg(m)
}
@@ -198,14 +195,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
requestID := resutil.GenerateRequestID()
fields := log.Fields{
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
@@ -268,9 +261,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
cw.response.Len(), meta, time.Since(startTime))
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {

View File

@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}
// TestLocalResolver_Stop tests cleanup on GracefullyStop
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("GracefullyStop clears all state", func(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})

View File

@@ -90,11 +90,6 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// Mock implementation - no-op
}
// SetFirewall mock implementation of SetFirewall from Server interface
func (m *MockServer) SetFirewall(Firewall) {
// Mock implementation - no-op
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op

View File

@@ -104,23 +104,3 @@ func (r *responseWriter) TsigTimersOnly(bool) {
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
var srcIP net.IP
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
srcIP = ipv4.(*layers.IPv4).SrcIP
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
srcIP = ipv6.(*layers.IPv6).SrcIP
}
var srcPort int
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
srcPort = int(udp.(*layers.UDP).SrcPort)
}
if srcIP == nil {
return nil
}
return &net.UDPAddr{IP: srcIP, Port: srcPort}
}

View File

@@ -58,7 +58,6 @@ type Server interface {
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
SetRouteChecker(func(netip.Addr) bool)
SetFirewall(Firewall)
}
type nsGroupsByDomain struct {
@@ -152,7 +151,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
dnsService = newServiceViaListener(config.WgInterface, addrPort)
}
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
@@ -187,16 +186,11 @@ func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
hostsDnsList []netip.AddrPort,
statusRecorder *peer.Status,
disableSys bool,
) *DefaultServer {
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.iosDnsManager = iosDnsManager
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
return ds
}
@@ -380,17 +374,6 @@ func (s *DefaultServer) DnsIP() netip.Addr {
return s.service.RuntimeIP()
}
// SetFirewall sets the firewall used for DNS port DNAT rules.
// This must be called before Initialize when using the listener-based service,
// because the firewall is typically not available at construction time.
func (s *DefaultServer) SetFirewall(fw Firewall) {
if svc, ok := s.service.(*serviceViaListener); ok {
svc.listenerFlagLock.Lock()
svc.firewall = fw
svc.listenerFlagLock.Unlock()
}
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
@@ -412,12 +395,8 @@ func (s *DefaultServer) Stop() {
maps.Clear(s.extraDomains)
}
func (s *DefaultServer) disableDNS() (retErr error) {
defer func() {
if err := s.service.Stop(); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
}
}()
func (s *DefaultServer) disableDNS() error {
defer s.service.Stop()
if s.isUsingNoopHostManager() {
return nil

View File

@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
type mockService struct{}
func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() error { return nil }
func (m *mockService) Stop() {}
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {}

View File

@@ -4,25 +4,15 @@ import (
"net/netip"
"github.com/miekg/dns"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
DefaultPort = 53
)
// Firewall provides DNAT capabilities for DNS port redirection.
// This is used when the DNS server cannot bind port 53 directly
// and needs firewall rules to redirect traffic.
type Firewall interface {
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
}
type service interface {
Listen() error
Stop() error
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int

View File

@@ -10,13 +10,9 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
@@ -35,33 +31,25 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
tcpServer *dns.Server
listenIP netip.Addr
listenPort uint16
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
firewall Firewall
tcpDNATConfigured bool
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
firewall: fw,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
tcpServer: &dns.Server{
Net: "tcp",
Handler: mux,
},
}
return s
@@ -82,86 +70,43 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
s.server.Addr = addr
s.tcpServer.Addr = addr
log.Debugf("starting dns on %s (UDP + TCP)", addr)
s.listenerIsRunning = true
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
if err := s.server.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
}
s.setListenerStatus(true)
defer s.setListenerStatus(false)
s.listenerFlagLock.Lock()
unexpected := s.listenerIsRunning
s.listenerIsRunning = false
s.listenerFlagLock.Unlock()
if unexpected {
if err := s.tcpServer.Shutdown(); err != nil {
log.Debugf("failed to shutdown DNS TCP server: %v", err)
}
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
}
}()
go func() {
if err := s.tcpServer.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
}
}()
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
// a DNAT rule because eBPF only handles UDP.
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
} else {
s.tcpDNATConfigured = true
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
}
}
return nil
}
func (s *serviceViaListener) Stop() error {
func (s *serviceViaListener) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return nil
return
}
s.listenerIsRunning = false
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var merr *multierror.Error
if err := s.server.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
}
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
}
if s.tcpDNATConfigured && s.firewall != nil {
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
s.tcpDNATConfigured = false
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
if s.ebpfService != nil {
if err := s.ebpfService.FreeDNSFwd(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -188,6 +133,12 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
return s.listenIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running
}
// evalListenAddress figure out the listen address for the DNS server
// first check the 53 port availability on WG interface or lo, if not success
@@ -236,28 +187,18 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrPort := netip.AddrPortFrom(ip, uint16(port))
udpAddr := net.UDPAddrFromAddrPort(addrPort)
udpLn, err := net.ListenUDP("udp", udpAddr)
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err != nil {
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
return false
}
if err := udpLn.Close(); err != nil {
log.Debugf("close UDP probe listener: %s", err)
}
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
err = probeListener.Close()
if err != nil {
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
return false
log.Errorf("got an error closing the probe listener, error: %s", err)
}
if err := tcpLn.Close(); err != nil {
log.Debugf("close TCP probe listener: %s", err)
}
return true
}

View File

@@ -1,86 +0,0 @@
package dns
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("192.0.2.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// Create a service using a custom address to avoid needing root
svc := newServiceViaListener(nil, nil, nil)
svc.dnsMux.Handle(".", handler)
// Bind both transports up front to avoid TOCTOU races.
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Skip("cannot bind to 127.0.0.153, skipping")
}
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
udpConn.Close()
t.Skip("cannot bind TCP on same port, skipping")
}
addr := fmt.Sprintf("%s:%d", customIP, port)
svc.server.PacketConn = udpConn
svc.tcpServer.Listener = tcpLn
svc.listenIP = customIP
svc.listenPort = port
go func() {
if err := svc.server.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := svc.tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
svc.listenerIsRunning = true
defer func() {
require.NoError(t, svc.Stop())
}()
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
// Test UDP query
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
udpResp, _, err := udpClient.Exchange(q, addr)
require.NoError(t, err, "UDP query should succeed")
require.NotNil(t, udpResp)
require.NotEmpty(t, udpResp.Answer)
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
// Test TCP query
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
tcpResp, _, err := tcpClient.Exchange(q, addr)
require.NoError(t, err, "TCP query should succeed")
require.NotNil(t, tcpResp)
require.NotEmpty(t, tcpResp.Answer)
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
}

View File

@@ -1,7 +1,6 @@
package dns
import (
"errors"
"fmt"
"net/netip"
"sync"
@@ -11,7 +10,6 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -20,8 +18,7 @@ type ServiceViaMemory struct {
dnsMux *dns.ServeMux
runtimeIP netip.Addr
runtimePort int
tcpDNS *tcpDNSServer
tcpHookSet bool
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
@@ -31,13 +28,14 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
return &ServiceViaMemory{
s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP,
runtimePort: DefaultPort,
}
return s
}
func (s *ServiceViaMemory) Listen() error {
@@ -48,8 +46,10 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
if err := s.filterDNSTraffic(); err != nil {
return fmt.Errorf("filter dns traffic: %w", err)
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return fmt.Errorf("filter dns traffice: %w", err)
}
s.listenerIsRunning = true
@@ -57,29 +57,19 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
func (s *ServiceViaMemory) Stop() error {
func (s *ServiceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return nil
return
}
filter := s.wgInterface.GetFilter()
if filter != nil {
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
if s.tcpHookSet {
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
}
}
if s.tcpDNS != nil {
s.tcpDNS.Stop()
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
return nil
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
@@ -98,18 +88,10 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
return s.runtimeIP
}
func (s *ServiceViaMemory) filterDNSTraffic() error {
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return errors.New("DNS filter not initialized")
}
// Create TCP DNS server lazily here since the device may not exist at construction time.
if s.tcpDNS == nil {
if dev := s.wgInterface.GetDevice(); dev != nil {
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
}
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
@@ -118,16 +100,12 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
if udpLayer == nil {
return true
}
udp, ok := udpLayer.(*layers.UDP)
if !ok {
return true
}
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
@@ -135,30 +113,13 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
return true
}
dev := s.wgInterface.GetDevice()
if dev == nil {
return true
}
writer := &responseWriter{
remote: remoteAddrFromPacket(packet),
writer := responseWriter{
packet: packet,
device: dev.Device,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(writer, msg)
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
if s.tcpDNS != nil {
tcpHook := func(packetData []byte) bool {
s.tcpDNS.InjectPacket(packetData)
return true
}
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
s.tcpHookSet = true
}
return nil
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
}

View File

@@ -1,444 +0,0 @@
package dns
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
dnsTCPReceiveWindow = 8192
dnsTCPMaxInFlight = 16
dnsTCPIdleTimeout = 30 * time.Second
dnsTCPReadTimeout = 5 * time.Second
)
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
// It is started lazily when a truncated DNS response is detected and shuts down
// after a period of inactivity to conserve resources.
type tcpDNSServer struct {
mu sync.Mutex
s *stack.Stack
ep *dnsEndpoint
mux *dns.ServeMux
tunDev tun.Device
ip netip.Addr
port uint16
mtu uint16
running bool
closed bool
timerID uint64
timer *time.Timer
}
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
return &tcpDNSServer{
mux: mux,
tunDev: tunDev,
ip: ip,
port: port,
mtu: mtu,
}
}
// InjectPacket ensures the stack is running and delivers a raw IP packet into
// the gvisor stack for TCP processing. Combining both operations under a single
// lock prevents a race where the idle timer could stop the stack between
// start and delivery.
func (t *tcpDNSServer) InjectPacket(payload []byte) {
t.mu.Lock()
defer t.mu.Unlock()
if t.closed {
return
}
if !t.running {
if err := t.startLocked(); err != nil {
log.Errorf("failed to start TCP DNS stack: %v", err)
return
}
t.running = true
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
}
t.resetTimerLocked()
ep := t.ep
if ep == nil || ep.dispatcher == nil {
return
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
// Stop tears down the gvisor stack and releases resources permanently.
// After Stop, InjectPacket becomes a no-op.
func (t *tcpDNSServer) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
t.stopLocked()
t.closed = true
}
func (t *tcpDNSServer) startLocked() error {
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
HandleLocal: false,
})
nicID := tcpip.NICID(1)
ep := &dnsEndpoint{
tunDev: t.tunDev,
}
ep.mtu.Store(uint32(t.mtu))
if err := s.CreateNIC(nicID, ep); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
PrefixLen: 32,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("add protocol address: %s", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set spoofing: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create default subnet: %w", err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: defaultSubnet, NIC: nicID},
})
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
t.handleTCPDNS(r)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
t.s = s
t.ep = ep
return nil
}
func (t *tcpDNSServer) stopLocked() {
if !t.running {
return
}
if t.timer != nil {
t.timer.Stop()
t.timer = nil
}
if t.s != nil {
t.s.Close()
t.s.Wait()
t.s = nil
}
t.ep = nil
t.running = false
log.Debugf("TCP DNS stack stopped")
}
func (t *tcpDNSServer) resetTimerLocked() {
if t.timer != nil {
t.timer.Stop()
}
t.timerID++
id := t.timerID
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
t.mu.Lock()
defer t.mu.Unlock()
// Only stop if this timer is still the active one.
// A racing InjectPacket may have replaced it.
if t.timerID != id {
return
}
t.stopLocked()
})
}
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
id := r.ID()
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
r.Complete(true)
return
}
r.Complete(false)
conn := gonet.NewTCPConn(&wq, ep)
defer func() {
if err := conn.Close(); err != nil {
log.Tracef("TCP DNS: close conn: %v", err)
}
}()
// Reset idle timer on activity
t.mu.Lock()
t.resetTimerLocked()
t.mu.Unlock()
localAddr := &net.TCPAddr{
IP: id.LocalAddress.AsSlice(),
Port: int(id.LocalPort),
}
remoteAddr := &net.TCPAddr{
IP: id.RemoteAddress.AsSlice(),
Port: int(id.RemotePort),
}
for {
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
break
}
msg, err := readTCPDNSMessage(conn)
if err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
}
break
}
writer := &tcpResponseWriter{
conn: conn,
localAddr: localAddr,
remoteAddr: remoteAddr,
}
t.mux.ServeDNS(writer, msg)
}
}
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
type dnsEndpoint struct {
dispatcher stack.NetworkDispatcher
tunDev tun.Device
mtu atomic.Uint32
}
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
func (e *dnsEndpoint) Wait() { /* no async work */ }
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
const tunPacketOffset = 40
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
raw := data.AsSlice()
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
buf = append(buf, raw...)
data.Release()
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
continue
}
written++
}
return written, nil
}
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
type tcpResponseWriter struct {
conn *gonet.TCPConn
localAddr net.Addr
remoteAddr net.Addr
}
func (w *tcpResponseWriter) LocalAddr() net.Addr {
return w.localAddr
}
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
return w.remoteAddr
}
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
data, err := msg.Pack()
if err != nil {
return fmt.Errorf("pack: %w", err)
}
// DNS TCP: 2-byte length prefix + message
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err = w.conn.Write(buf); err != nil {
return err
}
return nil
}
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err := w.conn.Write(buf); err != nil {
return 0, err
}
return len(data), nil
}
func (w *tcpResponseWriter) Close() error {
return w.conn.Close()
}
func (w *tcpResponseWriter) TsigStatus() error { return nil }
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
// DNS over TCP uses a 2-byte length prefix
lenBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return nil, fmt.Errorf("read length: %w", err)
}
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
if msgLen == 0 || msgLen > 65535 {
return nil, fmt.Errorf("invalid message length: %d", msgLen)
}
msgBuf := make([]byte, msgLen)
if _, err := io.ReadFull(conn, msgBuf); err != nil {
return nil, fmt.Errorf("read message: %w", err)
}
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
return nil, fmt.Errorf("unpack: %w", err)
}
return msg, nil
}
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
// Supports both IPv4 and IPv6.
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
if len(pkt) == 0 {
return netip.AddrPort{}
}
srcIP, transportOffset := srcIPFromPacket(pkt)
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
return netip.AddrPort{}
}
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
}
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
switch header.IPVersion(pkt) {
case 4:
return srcIPv4(pkt)
case 6:
return srcIPv6(pkt)
default:
return netip.Addr{}, 0
}
}
func srcIPv4(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv4MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv4(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, int(hdr.HeaderLength())
}
func srcIPv6(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv6MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv6(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, header.IPv6MinimumSize
}

View File

@@ -41,61 +41,10 @@ const (
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
const (
protoUDP = "udp"
protoTCP = "tcp"
)
type dnsProtocolKey struct{}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
}
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
func dnsProtocolFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
return v
}
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
func setUpstreamProtocol(ctx context.Context, protocol string) {
if ctx == nil {
return
}
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
r.protocol = protocol
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
@@ -189,16 +138,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// Propagate inbound protocol so upstream exchange can use TCP directly
// when the request came in over TCP.
ctx := u.ctx
if addr := w.RemoteAddr(); addr != nil {
network := addr.Network()
ctx = contextWithDNSProtocol(ctx, network)
resutil.SetMeta(w, "protocol", network)
}
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
ok, failures := u.tryUpstreamServers(w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
@@ -213,7 +153,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -228,7 +168,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
@@ -238,17 +178,15 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
@@ -265,7 +203,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
return nil
}
@@ -282,13 +220,10 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
}
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
@@ -493,42 +428,13 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
return int(opt.UDPSize())
}
return dns.MinMsgSize
}
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, protoTCP)
return rm, t, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than our read buffer.
// Note: the query could be sent out on an interface that is not ours,
// but higher MTU settings could break truncation handling.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
client.UDPSize = maxUDPPayload
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = uint16(currentMTU - (60 + 8))
var (
rm *dns.Msg
@@ -547,32 +453,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
}
if rm == nil || !rm.MsgHdr.Truncated {
setUpstreamProtocol(ctx, protoUDP)
return rm, t, nil
}
// TODO: if the upstream's truncated UDP response already contains more
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
tcpClient := *client
tcpClient.Net = protoTCP
client.Net = "tcp"
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
return rm, t, nil
}
@@ -580,46 +479,18 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
// If request came in over TCP, go straight to TCP upstream
if dnsProtocolFromContext(ctx) == protoTCP {
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
return rm, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than what we can read over UDP.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
return nil, err
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, nil
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
setUpstreamProtocol(ctx, protoUDP)
return reply, nil
}
@@ -640,7 +511,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
}
}
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)

View File

@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
Timeout: timeout,
}
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -475,298 +475,3 @@ func TestFormatFailures(t *testing.T) {
})
}
}
func TestDNSProtocolContext(t *testing.T) {
t.Run("roundtrip udp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
})
t.Run("roundtrip tcp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
})
t.Run("missing returns empty", func(t *testing.T) {
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
})
}
func TestExchangeWithFallback_TCPContext(t *testing.T) {
// Start a local DNS server that responds on TCP only
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpServer := &dns.Server{
Addr: "127.0.0.1:0",
Net: "tcp",
Handler: tcpHandler,
}
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer.Listener = tcpLn
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// With TCP context, should connect directly via TCP without trying UDP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
}
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
// UDP handler returns a truncated response to trigger TCP retry.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// TCP handler returns the full answer.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.3"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{
PacketConn: udpPC,
Net: "udp",
Handler: udpHandler,
}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := udpServer.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
}()
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
assert.False(t, rm.Truncated, "TCP response should not be truncated")
}
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
// Start only a TCP server (no UDP). With TCP context it should succeed.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.2"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// TCP context: should skip UDP entirely and go directly to TCP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
// Without TCP context, trying to reach a TCP-only server via UDP should fail
ctx2 := context.Background()
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
assert.Error(t, err, "should fail when no UDP server and no TCP context")
}
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
t.Cleanup(func() { _ = udpServer.Shutdown() })
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
"upstream should see capped EDNS0, not the client's 4096")
}
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
// When the client advertises a large EDNS0 (4096) and the upstream
// truncates, the TCP response should NOT be truncated since the full
// answer fits within the client's original buffer.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
// Add enough records to exceed MTU but fit within 4096
for i := range 20 {
m.Answer = append(m.Answer, &dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
})
}
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
go func() { _ = tcpServer.ActivateAndServe() }()
t.Cleanup(func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
})
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
// Client with large buffer: should get all records without truncation
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
// Client with small buffer: should get truncated response
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r2.SetEdns0(512, false)
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
require.NoError(t, err)
require.NotNil(t, rm2)
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}

View File

@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
@@ -263,28 +263,20 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
fields := log.Fields{
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
})
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
fields := log.Fields{
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
})
f.handleDNSQuery(logger, w, query, startTime)
}

View File

@@ -46,7 +46,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
@@ -211,10 +210,9 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
stateManager *statemanager.Manager
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
@@ -261,27 +259,26 @@ func NewEngine(
mobileDep MobileDependency,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -524,11 +521,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return err
}
// Inject firewall into DNS server now that it's available.
// The DNS server is created before the firewall because the route manager
// depends on the DNS server, and the firewall depends on the wg interface.
e.dnsServer.SetFirewall(e.firewall)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
@@ -540,13 +532,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -569,7 +554,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -1550,13 +1535,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
PortForwardManager: e.portForwardManager,
MetricsRecorder: e.clientMetrics,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
MetricsRecorder: e.clientMetrics,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1713,12 +1697,6 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1822,7 +1800,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:
@@ -1859,11 +1837,6 @@ func (e *Engine) GetExposeManager() *expose.Manager {
return e.exposeManager
}
// IsBlockInbound returns whether inbound connections are blocked.
func (e *Engine) IsBlockInbound() bool {
return e.config.BlockInbound
}
// GetClientMetrics returns the client metrics
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics

View File

@@ -22,7 +22,6 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -46,7 +45,6 @@ type ServiceDependencies struct {
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
PeerConnDispatcher *dispatcher.ConnectionDispatcher
PortForwardManager *portforward.Manager
MetricsRecorder MetricsRecorder
}
@@ -89,17 +87,16 @@ type ConnConfig struct {
}
type Conn struct {
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
portForwardManager *portforward.Manager
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
@@ -148,20 +145,19 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
portForwardManager: services.PortForwardManager,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}
return conn, nil
@@ -185,17 +181,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
forceRelay := IsForceRelayed()
if !forceRelay {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() {
if !forceRelay {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
}
@@ -251,7 +250,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
conn.workerICE.Close()
if conn.workerICE != nil {
conn.workerICE.Close()
}
if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn()
@@ -294,7 +295,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate()
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
if conn.workerICE != nil {
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
@@ -721,14 +724,19 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
// For force-relayed connections (JS or NB_FORCE_RELAY): only relay status matters
if IsForceRelayed() {
if !conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
return false
}
return conn.statusRelay.Get() == worker.StatusConnected
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
// For non-forced platforms: check ICE connection status only if remote peer supports ICE
if conn.handshaker.RemoteICESupported() {
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
}
// If relay is supported with peer, it must also be connected

View File

@@ -10,7 +10,7 @@ const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func isForceRelayed() bool {
func IsForceRelayed() bool {
if runtime.GOOS == "js" {
return true
}

View File

@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
return srw
}
func (w *SRWatcher) Start() {
func (w *SRWatcher) Start(disableICEMonitor bool) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
if !disableICEMonitor {
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
}
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
@@ -59,6 +60,10 @@ type Handshaker struct {
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
remoteICESupported atomic.Bool
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
@@ -66,7 +71,7 @@ type Handshaker struct {
}
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
return &Handshaker{
h := &Handshaker{
log: log,
config: config,
signaler: signaler,
@@ -76,6 +81,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
// assume remote supports ICE until we learn otherwise from received offers
h.remoteICESupported.Store(ice != nil)
return h
}
func (h *Handshaker) RemoteICESupported() bool {
return h.remoteICESupported.Load()
}
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
@@ -97,11 +109,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer)
}
@@ -117,11 +131,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer)
}
case <-ctx.Done():
@@ -183,15 +199,18 @@ func (h *Handshaker) sendAnswer() error {
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
}
if h.ice != nil && h.RemoteICESupported() {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer.IceCredentials = IceCredentials{uFrag, pwd}
answer.SessionID = &sid
}
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
@@ -200,3 +219,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
return answer
}
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
hasICE := offer.IceCredentials.UFrag != "" && offer.IceCredentials.Pwd != ""
prev := h.remoteICESupported.Swap(hasICE)
if prev != hasICE {
if hasICE {
h.log.Infof("remote peer started sending ICE credentials")
} else {
h.log.Infof("remote peer stopped sending ICE credentials")
if h.ice != nil {
h.ice.Close()
}
}
}
}

View File

@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
var sessionIDBytes []byte
if offerAnswer.SessionID != nil {
var err error
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
}
msg, err := signal.MarshalCredential(
s.wgPrivateKey,

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)
@@ -62,9 +61,6 @@ type WorkerICE struct {
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
// portForwardAttempted tracks if we've already tried port forwarding this session
portForwardAttempted bool
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
@@ -218,8 +214,6 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
w.portForwardAttempted = false
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
@@ -376,93 +370,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if candidate.Type() == ice.CandidateTypeServerReflexive {
w.injectPortForwardedCandidate(candidate)
}
}
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
pfManager := w.conn.portForwardManager
if pfManager == nil {
return
}
mapping := pfManager.GetMapping()
if mapping == nil {
return
}
w.muxAgent.Lock()
if w.portForwardAttempted {
w.muxAgent.Unlock()
return
}
w.portForwardAttempted = true
w.muxAgent.Unlock()
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
if err != nil {
w.log.Warnf("create forwarded candidate: %v", err)
return
}
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
go func() {
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
w.log.Errorf("signal port-forwarded candidate: %v", err)
}
}()
}
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
// It uses the NAT gateway's external IP with the forwarded port.
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
var externalIP string
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
externalIP = mapping.ExternalIP.String()
} else {
// Fallback to STUN-discovered address if NAT didn't provide external IP
externalIP = srflxCandidate.Address()
}
// Per RFC 8445, the related address for srflx is the base (host candidate address).
// If the original srflx has unspecified related address, use its own address as base.
relAddr := srflxCandidate.RelatedAddress().Address
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
relAddr = srflxCandidate.Address()
}
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
// over regular srflx during ICE connectivity checks.
priority := srflxCandidate.Priority() + 1000
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: srflxCandidate.NetworkType().String(),
Address: externalIP,
Port: int(mapping.ExternalPort),
Component: srflxCandidate.Component(),
Priority: priority,
RelAddr: relAddr,
RelPort: int(mapping.InternalPort),
})
if err != nil {
return nil, fmt.Errorf("create candidate: %w", err)
}
for _, e := range srflxCandidate.Extensions() {
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = srflxCandidate.ID()
}
if err := candidate.AddExtension(e); err != nil {
return nil, fmt.Errorf("add extension: %w", err)
}
}
return candidate, nil
}
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
@@ -504,10 +411,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(), local.Port(),
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
stat.CurrentRoundTripTime*1000)
}
}

View File

@@ -1,26 +0,0 @@
package portforward
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
)
func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
return false
}
return disabled
}

View File

@@ -1,250 +0,0 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
)
const (
defaultMappingTTL = 2 * time.Hour
renewalInterval = defaultMappingTTL / 2
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
}
type Manager struct {
cancel context.CancelFunc
mapping *Mapping
mappingLock sync.Mutex
wgPort uint16
done chan struct{}
stopCtx chan context.Context
// protect exported functions
mu sync.Mutex
}
func NewManager() *Manager {
return &Manager{
stopCtx: make(chan context.Context, 1),
}
}
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
m.mu.Lock()
if m.cancel != nil {
m.mu.Unlock()
return
}
if isDisabledByEnv() {
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
m.mu.Unlock()
return
}
if wgPort == 0 {
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
m.mu.Unlock()
return
}
m.wgPort = wgPort
m.done = make(chan struct{})
defer close(m.done)
ctx, m.cancel = context.WithCancel(ctx)
m.mu.Unlock()
gateway, mapping, err := m.setup(ctx)
if err != nil {
log.Errorf("failed to setup NAT port mapping: %v", err)
return
}
m.mappingLock.Lock()
m.mapping = mapping
m.mappingLock.Unlock()
m.renewLoop(ctx, gateway)
select {
case cleanupCtx := <-m.stopCtx:
// block the Start while cleaned up gracefully
m.cleanup(cleanupCtx, gateway)
default:
// return Start immediately and cleanup in background
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
defer cleanupCancel()
m.cleanup(cleanupCtx, gateway)
}()
}
}
// GetMapping returns the current mapping if ready, nil otherwise
func (m *Manager) GetMapping() *Mapping {
m.mappingLock.Lock()
defer m.mappingLock.Unlock()
if m.mapping == nil {
return nil
}
mapping := *m.mapping
return &mapping
}
// GracefullyStop cancels the manager and attempts to delete the port mapping.
// After GracefullyStop returns, the manager cannot be restarted.
func (m *Manager) GracefullyStop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil {
return nil
}
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
m.startTearDown(ctx)
m.cancel()
m.cancel = nil
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := nat.DiscoverGateway(discoverCtx)
if err != nil {
log.Infof("NAT gateway discovery failed: %v (port forwarding disabled)", err)
return nil, nil, err
}
log.Infof("discovered NAT gateway: %s", gateway.Type())
mapping, err := m.createMapping(ctx, gateway)
if err != nil {
log.Warnf("failed to create port mapping: %v", err)
return nil, nil, err
}
return gateway, mapping, nil
}
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL)
if err != nil {
return nil, err
}
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}
mapping := &Mapping{
Protocol: "udp",
InternalPort: m.wgPort,
ExternalPort: uint16(externalPort),
ExternalIP: externalIP,
NATType: gateway.Type(),
}
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
m.wgPort, externalPort, gateway.Type(), externalIP)
return mapping, nil
}
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT) {
ticker := time.NewTicker(renewalInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
}
}
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL)
if err != nil {
return fmt.Errorf("add port mapping: %w", err)
}
if uint16(externalPort) != m.mapping.ExternalPort {
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
m.mappingLock.Lock()
m.mapping.ExternalPort = uint16(externalPort)
m.mappingLock.Unlock()
}
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
return nil
}
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
m.mappingLock.Lock()
mapping := m.mapping
m.mapping = nil
m.mappingLock.Unlock()
if mapping == nil {
return
}
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
log.Warnf("delete port mapping on stop: %v", err)
return
}
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
}
func (m *Manager) startTearDown(ctx context.Context) {
select {
case m.stopCtx <- ctx:
default:
}
}

View File

@@ -1,36 +0,0 @@
package portforward
import (
"context"
"net"
)
// Mapping represents port mapping information.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
}
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
type Manager struct{}
// NewManager returns a stub manager for js/wasm builds.
func NewManager() *Manager {
return &Manager{}
}
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
func (m *Manager) Start(context.Context, uint16) {
// no NAT traversal in wasm
}
// GracefullyStop is a no-op on js/wasm.
func (m *Manager) GracefullyStop(context.Context) error { return nil }
// GetMapping always returns nil on js/wasm.
func (m *Manager) GetMapping() *Mapping {
return nil
}

View File

@@ -1,159 +0,0 @@
//go:build !js
package portforward
import (
"context"
"net"
"testing"
"time"
"github.com/libp2p/go-nat"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockNAT struct {
natType string
deviceAddr net.IP
externalAddr net.IP
internalAddr net.IP
mappings map[int]int
addMappingErr error
deleteMappingErr error
}
func newMockNAT() *mockNAT {
return &mockNAT{
natType: "Mock-NAT",
deviceAddr: net.ParseIP("192.168.1.1"),
externalAddr: net.ParseIP("203.0.113.50"),
internalAddr: net.ParseIP("192.168.1.100"),
mappings: make(map[int]int),
}
}
func (m *mockNAT) Type() string {
return m.natType
}
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
return m.deviceAddr, nil
}
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
return m.externalAddr, nil
}
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
return m.internalAddr, nil
}
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
if m.addMappingErr != nil {
return 0, m.addMappingErr
}
externalPort := internalPort
m.mappings[internalPort] = externalPort
return externalPort, nil
}
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if m.deleteMappingErr != nil {
return m.deleteMappingErr
}
delete(m.mappings, internalPort)
return nil
}
func TestManager_CreateMapping(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, "udp", mapping.Protocol)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, uint16(51820), mapping.ExternalPort)
assert.Equal(t, "Mock-NAT", mapping.NATType)
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
}
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
m := NewManager()
assert.Nil(t, m.GetMapping())
}
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
mapping := m.GetMapping()
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
// Mutating the returned copy should not affect the manager's mapping.
mapping.ExternalPort = 9999
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
}
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
gateway := newMockNAT()
// Seed the mock so we can verify deletion.
gateway.mappings[51820] = 51820
m.cleanup(context.Background(), gateway)
_, exists := gateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted from gateway")
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
}
func TestManager_Cleanup_NilMapping(t *testing.T) {
m := NewManager()
gateway := newMockNAT()
// Should not panic or call gateway.
m.cleanup(context.Background(), gateway)
}
func TestState_Cleanup(t *testing.T) {
origDiscover := discoverGateway
defer func() { discoverGateway = origDiscover }()
mockGateway := newMockNAT()
mockGateway.mappings[51820] = 51820
discoverGateway = func(ctx context.Context) (nat.NAT, error) {
return mockGateway, nil
}
state := &State{
Protocol: "udp",
InternalPort: 51820,
}
err := state.Cleanup()
assert.NoError(t, err)
_, exists := mockGateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted after cleanup")
}
func TestState_Name(t *testing.T) {
state := &State{}
assert.Equal(t, "port_forward_state", state.Name())
}

View File

@@ -1,50 +0,0 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
)
// discoverGateway is the function used for NAT gateway discovery.
// It can be replaced in tests to avoid real network operations.
var discoverGateway = nat.DiscoverGateway
// State is persisted only for crash recovery cleanup
type State struct {
InternalPort uint16 `json:"internal_port,omitempty"`
Protocol string `json:"protocol,omitempty"`
}
func (s *State) Name() string {
return "port_forward_state"
}
// Cleanup implements statemanager.CleanableState for crash recovery
func (s *State) Cleanup() error {
if s.InternalPort == 0 {
return nil
}
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
defer cancel()
gateway, err := discoverGateway(ctx)
if err != nil {
// Discovery failure is not an error - gateway may not exist
log.Debugf("cleanup: no gateway found: %v", err)
return nil
}
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
return fmt.Errorf("delete port mapping: %w", err)
}
return nil
}

View File

@@ -53,6 +53,7 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
n.currentPrefixes = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()

View File

@@ -161,11 +161,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -1359,10 +1359,6 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
if engine.IsBlockInbound() {
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")

View File

@@ -9,11 +9,6 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
// Note: portforward.State is intentionally NOT registered here to avoid blocking startup
// for up to 10 seconds during NAT gateway discovery when no gateway is present.
// The gateway reference cannot be persisted across restarts, so cleanup requires re-discovery.
// Port forward cleanup is handled by the Manager during normal operation instead.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -11,11 +11,6 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
// Note: portforward.State is intentionally NOT registered here to avoid blocking startup
// for up to 10 seconds during NAT gateway discovery when no gateway is present.
// The gateway reference cannot be persisted across restarts, so cleanup requires re-discovery.
// Port forward cleanup is handled by the Manager during normal operation instead.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := session.RawCommand() != ""
hasCommand := len(session.Command()) > 0
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
}
if hasCommand {
if err := serverSession.Run(session.RawCommand()); err != nil {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}

View File

@@ -1,7 +1,6 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -246,191 +245,6 @@ func TestSSHProxy_Connect(t *testing.T) {
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte

View File

@@ -284,21 +284,19 @@ func (s *Server) closeListener(ln net.Listener) {
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
sshServer := s.sshServer
if sshServer == nil {
s.mu.Unlock()
defer s.mu.Unlock()
if s.sshServer == nil {
return nil
}
s.sshServer = nil
s.listener = nil
s.mu.Unlock()
// Close outside the lock: session handlers need s.mu for unregisterSession.
if err := sshServer.Close(); err != nil {
if err := s.sshServer.Close(); err != nil {
log.Debugf("close SSH server: %v", err)
}
s.mu.Lock()
s.sshServer = nil
s.listener = nil
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.connections)
@@ -309,7 +307,6 @@ func (s *Server) Stop() error {
}
}
maps.Clear(s.remoteForwardListeners)
s.mu.Unlock()
return nil
}

View File

@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
}
ptyReq, winCh, isPty := session.Pty()
hasCommand := session.RawCommand() != ""
hasCommand := len(session.Command()) > 0
if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login)

View File

@@ -153,9 +153,6 @@ func networkAddresses() ([]NetworkAddress, error) {
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}

View File

@@ -24,10 +24,9 @@ import (
// Initial state for the debug collection
type debugInitialState struct {
wasDown bool
needsRestoreUp bool
logLevel proto.LogLevel
isLevelTrace bool
wasDown bool
logLevel proto.LogLevel
isLevelTrace bool
}
// Debug collection parameters
@@ -372,51 +371,46 @@ func (s *serviceClient) configureServiceForDebug(
conn proto.DaemonServiceClient,
state *debugInitialState,
enablePersistence bool,
) {
) error {
if state.wasDown {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Warnf("failed to bring service up: %v", err)
} else {
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
return fmt.Errorf("bring service up: %v", err)
}
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
log.Warnf("failed to set log level to TRACE: %v", err)
} else {
log.Info("Log level set to TRACE for debug")
return fmt.Errorf("set log level to TRACE: %v", err)
}
log.Info("Log level set to TRACE for debug")
}
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Warnf("failed to bring service down: %v", err)
} else {
state.needsRestoreUp = !state.wasDown
time.Sleep(time.Second)
return fmt.Errorf("bring service down: %v", err)
}
time.Sleep(time.Second)
if enablePersistence {
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
log.Warnf("failed to enable sync response persistence: %v", err)
} else {
log.Info("Sync response persistence enabled for debug")
return fmt.Errorf("enable sync response persistence: %v", err)
}
log.Info("Sync response persistence enabled for debug")
}
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Warnf("failed to bring service back up: %v", err)
} else {
state.needsRestoreUp = false
time.Sleep(time.Second * 3)
return fmt.Errorf("bring service back up: %v", err)
}
time.Sleep(time.Second * 3)
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
log.Warnf("failed to start CPU profiling: %v", err)
}
return nil
}
func (s *serviceClient) collectDebugData(
@@ -430,7 +424,9 @@ func (s *serviceClient) collectDebugData(
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
s.configureServiceForDebug(conn, state, params.enablePersistence)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return err
}
wg.Wait()
progress.progressBar.Hide()
@@ -486,17 +482,9 @@ func (s *serviceClient) createDebugBundleFromCollection(
// Restore service to original state
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
if state.needsRestoreUp {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Warnf("failed to restore up state: %v", err)
} else {
log.Info("Service state restored to up")
}
}
if state.wasDown {
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Warnf("failed to restore down state: %v", err)
log.Errorf("Failed to restore down state: %v", err)
} else {
log.Info("Service state restored to down")
}
@@ -504,7 +492,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
log.Warnf("failed to restore log level: %v", err)
log.Errorf("Failed to restore log level: %v", err)
} else {
log.Info("Log level restored to original setting")
}

View File

@@ -29,7 +29,6 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
@@ -524,7 +523,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn listener.Conn)
var relayAcceptFn func(conn net.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
@@ -564,7 +563,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
}
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) {
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
@@ -586,9 +585,15 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(
return
}
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr)
acceptFn(conn)
}

4
go.mod
View File

@@ -63,7 +63,6 @@ require (
github.com/hashicorp/go-version v1.6.0
github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0
github.com/libp2p/go-nat v0.2.0
github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
@@ -201,12 +200,10 @@ require (
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
@@ -216,7 +213,6 @@ require (
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/koron/go-ssdp v0.0.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/libdns/libdns v0.2.2 // indirect

8
go.sum
View File

@@ -281,8 +281,6 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY=
github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -293,8 +291,6 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
@@ -332,8 +328,6 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0=
github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -352,8 +346,6 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=

View File

@@ -85,8 +85,8 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}

View File

@@ -1119,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}

View File

@@ -698,8 +698,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}

View File

@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
}
// GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
}
// GetIdentityProvider mocks base method.

View File

@@ -61,10 +61,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}

View File

@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return groups, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}

View File

@@ -46,7 +46,7 @@ type MockAccountManager struct {
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
@@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer(
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
if am.GetGroupByNameFunc != nil {
return am.GetGroupByNameFunc(ctx, groupName, accountID, userID)
return am.GetGroupByNameFunc(ctx, accountID, groupName)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
}

View File

@@ -2080,8 +2080,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
mode, listen_port, port_auto_assigned, source, source_peer, terminated
pass_host_header, rewrite_redirects, session_private_key, session_public_key
FROM services WHERE account_id = $1`
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
@@ -2098,7 +2097,6 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString
err := row.Scan(
&s.ID,
&s.AccountID,
@@ -2114,12 +2112,6 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&s.RewriteRedirects,
&sessionPrivateKey,
&sessionPublicKey,
&mode,
&s.ListenPort,
&s.PortAutoAssigned,
&source,
&sourcePeer,
&s.Terminated,
)
if err != nil {
return nil, err
@@ -2151,15 +2143,6 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
if sessionPublicKey.Valid {
s.SessionPublicKey = sessionPublicKey.String
}
if mode.Valid {
s.Mode = mode.String
}
if source.Valid {
s.Source = source.String
}
if sourcePeer.Valid {
s.SourcePeer = sourcePeer.String
}
s.Targets = []*rpservice.Target{}
return &s, nil

View File

@@ -121,7 +121,7 @@ type Store interface {
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error

View File

@@ -165,6 +165,34 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
}
// GetClusterRequireSubdomain mocks base method.
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -1361,34 +1389,6 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx)
}
// GetClusterRequireSubdomain mocks base method.
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
}
// GetCustomDomain mocks base method.
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
m.ctrl.T.Helper()
@@ -1466,18 +1466,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou
}
// GetGroupByName mocks base method.
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) {
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID)
ret0, _ := ret[0].(*types2.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID)
}
// GetGroupsByIDs mocks base method.
@@ -1974,21 +1974,6 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID)
}
// GetRoutingPeerNetworks mocks base method.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
@@ -2376,6 +2361,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
}
// GetRoutingPeerNetworks mocks base method.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// IsPrimaryAccount mocks base method.
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
m.ctrl.T.Helper()

View File

@@ -1,13 +1,11 @@
package server
import (
"context"
"fmt"
"time"
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/shared/relay/messages/address"
@@ -15,12 +13,6 @@ import (
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
)
const (
// handshakeTimeout bounds how long a connection may remain in the
// pre-authentication handshake phase before being closed.
handshakeTimeout = 10 * time.Second
)
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
@@ -66,7 +58,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
}
type handshake struct {
conn listener.Conn
conn net.Conn
validator Validator
preparedMsg *preparedMsg
@@ -74,9 +66,9 @@ type handshake struct {
peerID *messages.PeerID
}
func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(ctx, buf)
n, err := h.conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
@@ -111,7 +103,7 @@ func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, err
return peerID, nil
}
func (h *handshake) handshakeResponse(ctx context.Context) error {
func (h *handshake) handshakeResponse() error {
var responseMsg []byte
if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg
@@ -119,7 +111,7 @@ func (h *handshake) handshakeResponse(ctx context.Context) error {
responseMsg = h.preparedMsg.responseHelloMsg
}
if _, err := h.conn.Write(ctx, responseMsg); err != nil {
if _, err := h.conn.Write(responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
}

View File

@@ -1,14 +0,0 @@
package listener
import (
"context"
"net"
)
// Conn is the relay connection contract implemented by WS and QUIC transports.
type Conn interface {
Read(ctx context.Context, b []byte) (n int, err error)
Write(ctx context.Context, b []byte) (n int, err error)
RemoteAddr() net.Addr
Close() error
}

View File

@@ -0,0 +1,14 @@
package listener
import (
"context"
"net"
"github.com/netbirdio/netbird/relay/protocol"
)
type Listener interface {
Listen(func(conn net.Conn)) error
Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
}

View File

@@ -3,26 +3,33 @@ package quic
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/quic-go/quic-go"
)
type Conn struct {
session *quic.Conn
closed bool
closedMu sync.Mutex
session *quic.Conn
closed bool
closedMu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConn(session *quic.Conn) *Conn {
ctx, cancel := context.WithCancel(context.Background())
return &Conn{
session: session,
session: session,
ctx: ctx,
ctxCancel: cancel,
}
}
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(ctx)
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
@@ -31,17 +38,33 @@ func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
return n, nil
}
func (c *Conn) Write(_ context.Context, b []byte) (int, error) {
func (c *Conn) Write(b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err)
}
return len(b), nil
}
func (c *Conn) LocalAddr() net.Addr {
return c.session.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {
@@ -51,6 +74,8 @@ func (c *Conn) Close() error {
c.closed = true
c.closedMu.Unlock()
c.ctxCancel() // Cancel the context
sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr
}

View File

@@ -5,12 +5,12 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
nbRelay "github.com/netbirdio/netbird/shared/relay"
)
@@ -25,7 +25,7 @@ type Listener struct {
listener *quic.Listener
}
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
quicCfg := &quic.Config{
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,

View File

@@ -18,21 +18,25 @@ const (
type Conn struct {
*websocket.Conn
lAddr *net.TCPAddr
rAddr *net.TCPAddr
closed bool
closedMu sync.Mutex
ctx context.Context
}
func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn {
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
return &Conn{
Conn: wsConn,
lAddr: lAddr,
rAddr: rAddr,
ctx: context.Background(),
}
}
func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
t, r, err := c.Reader(ctx)
func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx)
if err != nil {
return 0, c.ioErrHandling(err)
}
@@ -52,18 +56,34 @@ func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
// Write writes a binary message with the given payload.
// It does not block until fill the internal buffer.
// If the buffer filled up, wait until the buffer is drained or timeout.
func (c *Conn) Write(ctx context.Context, b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout)
func (c *Conn) Write(b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
defer ctxCancel()
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
return len(b), err
}
func (c *Conn) LocalAddr() net.Addr {
return c.lAddr
}
func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true

View File

@@ -7,13 +7,11 @@ import (
"fmt"
"net"
"net/http"
"time"
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay"
)
@@ -29,19 +27,18 @@ type Listener struct {
TLSConfig *tls.Config
server *http.Server
acceptFn func(conn relaylistener.Conn)
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc(URLPath, l.onAccept)
l.server = &http.Server{
Addr: l.Address,
Handler: mux,
TLSConfig: l.TLSConfig,
ReadHeaderTimeout: 5 * time.Second,
Addr: l.Address,
Handler: mux,
TLSConfig: l.TLSConfig,
}
log.Infof("WS server listening address: %s", l.Address)
@@ -96,9 +93,18 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return
}
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil {
log.Errorf("failed to close ws connection: %s", err)
}
return
}
log.Infof("WS client connected from: %s", rAddr)
conn := NewConn(wsConn, rAddr)
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
}

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/store"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
@@ -27,14 +26,11 @@ type Peer struct {
metrics *metrics.Metrics
log *log.Entry
id messages.PeerID
conn listener.Conn
conn net.Conn
connMu sync.RWMutex
store *store.Store
notifier *store.PeerNotifier
ctx context.Context
ctxCancel context.CancelFunc
peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
@@ -42,17 +38,14 @@ type Peer struct {
}
// NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
ctx, cancel := context.WithCancel(context.Background())
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
p := &Peer{
metrics: metrics,
log: log.WithField("peer_id", id.String()),
id: id,
conn: conn,
store: store,
notifier: notifier,
ctx: ctx,
ctxCancel: cancel,
metrics: metrics,
log: log.WithField("peer_id", id.String()),
id: id,
conn: conn,
store: store,
notifier: notifier,
}
return p
@@ -64,7 +57,6 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, s
func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() {
p.ctxCancel()
p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
@@ -72,7 +64,8 @@ func (p *Peer) Work() {
}
}()
ctx := p.ctx
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx)
@@ -80,7 +73,7 @@ func (p *Peer) Work() {
buf := make([]byte, bufferSize)
for {
n, err := p.conn.Read(ctx, buf)
n, err := p.conn.Read(buf)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err)
@@ -138,10 +131,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
}
// Write writes data to the connection
func (p *Peer) Write(ctx context.Context, b []byte) (int, error) {
func (p *Peer) Write(b []byte) (int, error) {
p.connMu.RLock()
defer p.connMu.RUnlock()
return p.conn.Write(ctx, b)
return p.conn.Write(b)
}
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
@@ -154,7 +147,6 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String())
}
p.ctxCancel()
if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err)
}
@@ -164,7 +156,6 @@ func (p *Peer) Close() {
p.connMu.Lock()
defer p.connMu.Unlock()
p.ctxCancel()
if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err)
}
@@ -179,15 +170,26 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
_, err := p.conn.Write(ctx, buf)
return err
writeDone := make(chan struct{})
var err error
go func() {
_, err = p.conn.Write(buf)
close(writeDone)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-writeDone:
return err
}
}
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
for {
select {
case <-hc.HealthCheck:
_, err := p.Write(ctx, messages.MarshalHealthcheck())
_, err := p.Write(messages.MarshalHealthcheck())
if err != nil {
p.log.Errorf("failed to send healthcheck message: %s", err)
return
@@ -226,12 +228,12 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return
}
n, err := dp.Write(dp.ctx, msg)
n, err := dp.Write(msg)
if err != nil {
p.log.Errorf("failed to write transport message to: %s", dp.String())
return
}
p.metrics.TransferBytesSent.Add(p.ctx, int64(n))
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
}
func (p *Peer) handleSubscribePeerState(msg []byte) {
@@ -274,7 +276,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
}
for n, msg := range msgs {
if _, err := p.Write(p.ctx, msg); err != nil {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
@@ -291,7 +293,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
}
for n, msg := range msgs {
if _, err := p.Write(p.ctx, msg); err != nil {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"net"
"net/url"
"sync"
"time"
@@ -12,20 +13,11 @@ import (
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
)
type Listener interface {
Listen(func(conn listener.Conn)) error
Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
}
type Config struct {
Meter metric.Meter
ExposedAddress string
@@ -117,7 +109,7 @@ func NewRelay(config Config) (*Relay, error) {
}
// Accept start to handle a new peer connection
func (r *Relay) Accept(conn listener.Conn) {
func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now()
r.closeMu.RLock()
defer r.closeMu.RUnlock()
@@ -125,15 +117,12 @@ func (r *Relay) Accept(conn listener.Conn) {
return
}
hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout)
defer hsCancel()
h := handshake{
conn: conn,
validator: r.validator,
preparedMsg: r.preparedMsg,
}
peerID, err := h.handshakeReceive(hsCtx)
peerID, err := h.handshakeReceive()
if err != nil {
if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr())
@@ -165,7 +154,7 @@ func (r *Relay) Accept(conn listener.Conn) {
r.metrics.PeerDisconnected(peer.String())
}()
if err := h.handshakeResponse(hsCtx); err != nil {
if err := h.handshakeResponse(); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close()
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/tls"
"net"
"net/url"
"sync"
@@ -30,7 +31,7 @@ type ListenerConfig struct {
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
relay *Relay
listeners []Listener
listeners []listener.Listener
listenerMux sync.Mutex
}
@@ -55,7 +56,7 @@ func NewServer(config Config) (*Server, error) {
}
return &Server{
relay: relay,
listeners: make([]Listener, 0, 2),
listeners: make([]listener.Listener, 0, 2),
}, nil
}
@@ -85,7 +86,7 @@ func (r *Server) Listen(cfg ListenerConfig) error {
wg := sync.WaitGroup{}
for _, l := range r.listeners {
wg.Add(1)
go func(listener Listener) {
go func(listener listener.Listener) {
defer wg.Done()
errChan <- listener.Listen(r.relay.Accept)
}(l)
@@ -138,6 +139,6 @@ func (r *Server) InstanceURL() url.URL {
// RelayAccept returns the relay's Accept function for handling incoming connections.
// This allows external HTTP handlers to route connections to the relay without
// starting the relay's own listeners.
func (r *Server) RelayAccept() func(conn listener.Conn) {
func (r *Server) RelayAccept() func(conn net.Conn) {
return r.relay.Accept
}