Compare commits

...

3 Commits

Author SHA1 Message Date
Pascal Fischer
3dd34c920e [management] add context cancel monitoring (#5879)
(cherry picked from commit c5623307cc)
2026-04-14 13:03:30 +02:00
Vlad
7989bbff3b [management] revert ctx dependency in get account with backpressure (#5878)
(cherry picked from commit 7f666b8022)
2026-04-14 13:02:56 +02:00
Viktor Liu
4eed459f27 [client] Fix DNS resolution with userspace WireGuard and kernel firewall (#5873) 2026-04-13 16:23:57 +02:00
8 changed files with 179 additions and 76 deletions

View File

@@ -56,6 +56,13 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
}

View File

@@ -0,0 +1,37 @@
package common
import (
"net/netip"
"sync/atomic"
)
// PacketHook stores a registered hook for a specific IP:port.
type PacketHook struct {
IP netip.Addr
Port uint16
Fn func([]byte) bool
}
// HookMatches checks if a packet's destination matches the hook and invokes it.
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.IP == dstIP && h.Port == dport {
return h.Fn(packetData)
}
return false
}
// SetHook atomically stores a hook, handling nil removal.
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
if hook == nil {
ptr.Store(nil)
return
}
ptr.Store(&PacketHook{
IP: ip,
Port: dPort,
Fn: hook,
})
}

View File

@@ -142,15 +142,8 @@ type Manager struct {
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
udpHookOut atomic.Pointer[common.PacketHook]
tcpHookOut atomic.Pointer[common.PacketHook]
}
// decoder for packages
@@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
}
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
return false
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
// filterInbound implements filtering logic for incoming packets.
@@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
common.SetHook(&m.udpHookOut, ip, dPort, hook)
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(8000), h.Port)
assert.True(t, h.Fn(nil))
assert.True(t, called)
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
@@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(53), h.Port)
assert.True(t, h.Fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)

View File

@@ -0,0 +1,90 @@
package uspfilter
import (
"encoding/binary"
"net/netip"
"sync/atomic"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
ipv4HeaderMinLen = 20
ipv4ProtoOffset = 9
ipv4FlagsOffset = 6
ipv4DstOffset = 16
ipProtoUDP = 17
ipProtoTCP = 6
ipv4FragOffMask = 0x1fff
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
dstPortOffset = 2
)
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
// It is installed on the WireGuard interface when the userspace bind is active
// but a full firewall filter (Manager) is not needed because a native kernel
// firewall (nftables/iptables) handles packet filtering.
type HooksFilter struct {
udpHook atomic.Pointer[common.PacketHook]
tcpHook atomic.Pointer[common.PacketHook]
}
var _ device.PacketFilter = (*HooksFilter)(nil)
// FilterOutbound checks outbound packets for DNS hook matches.
// Only IPv4 packets matching the registered hook IP:port are intercepted.
// IPv6 and non-IP packets pass through unconditionally.
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
if len(packetData) < ipv4HeaderMinLen {
return false
}
// Only process IPv4 packets, let everything else pass through.
if packetData[0]>>4 != 4 {
return false
}
ihl := int(packetData[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
return false
}
// Skip non-first fragments: they don't carry L4 headers.
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
if flagsAndOffset&ipv4FragOffMask != 0 {
return false
}
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
if !ok {
return false
}
proto := packetData[ipv4ProtoOffset]
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
switch proto {
case ipProtoUDP:
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
case ipProtoTCP:
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
default:
return false
}
}
// FilterInbound allows all inbound packets (native firewall handles filtering).
func (f *HooksFilter) FilterInbound([]byte, int) bool {
return false
}
// SetUDPPacketHook registers the UDP packet hook.
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.udpHook, ip, dPort, hook)
}
// SetTCPPacketHook registers the TCP packet hook.
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.tcpHook, ip, dPort, hook)
}

View File

@@ -63,20 +63,11 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context,
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
startTime := time.Now()
ac.getAccountRequestCh <- req
select {
case <-ctx.Done():
return nil, ctx.Err()
case ac.getAccountRequestCh <- req:
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-req.ResultChan:
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
result := <-req.ResultChan
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {

View File

@@ -1017,10 +1017,10 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
// GetCustomDomainsCounts returns the total and validated custom domain counts.
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
var total, validated int64
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil {
return 0, 0, err
}
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
return 0, 0, err
}
return total, validated, nil
@@ -4442,7 +4442,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4461,7 +4461,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr
// GetAllProxyAccessTokens retrieves all proxy access tokens.
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4477,7 +4477,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc
// SaveProxyAccessToken saves a proxy access token to the database.
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
if result := s.db.Create(token); result.Error != nil {
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
}
return nil
@@ -4485,7 +4485,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA
// RevokeProxyAccessToken revokes a proxy access token by its ID.
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
if result.Error != nil {
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
}
@@ -4499,7 +4499,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
result := s.db.Model(&types.ProxyAccessToken{}).
Where(idQueryCondition, tokenID).
Update("last_used", time.Now().UTC())
if result.Error != nil {
@@ -5168,7 +5168,7 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -5184,7 +5184,7 @@ func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength
// GetServicesByCluster returns all services for the given proxy cluster.
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -5294,7 +5294,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
var logs []*accesslogs.AccessLogEntry
var totalCount int64
baseQuery := s.db.WithContext(ctx).
baseQuery := s.db.
Model(&accesslogs.AccessLogEntry{}).
Where(accountIDCondition, accountID)
@@ -5305,7 +5305,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
}
query := s.db.WithContext(ctx).
query := s.db.
Where(accountIDCondition, accountID)
query = s.applyAccessLogFilters(query, filter)
@@ -5342,7 +5342,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
// DeleteOldAccessLogs deletes all access logs older than the specified time
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
result := s.db.WithContext(ctx).
result := s.db.
Where("timestamp < ?", olderThan).
Delete(&accesslogs.AccessLogEntry{})
@@ -5431,7 +5431,7 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
result := s.db.Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
@@ -5443,7 +5443,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
result := s.db.WithContext(ctx).
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", now)
@@ -5462,7 +5462,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
ConnectedAt: &now,
Status: "connected",
}
if err := s.db.WithContext(ctx).Save(p).Error; err != nil {
if err := s.db.Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
}
@@ -5475,7 +5475,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
result := s.db.
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
@@ -5544,7 +5544,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
AnyTrue bool
}
err := s.db.WithContext(ctx).
err := s.db.
Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
@@ -5568,7 +5568,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
result := s.db.
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})

View File

@@ -183,7 +183,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
w := WrapResponseWriter(rw)
handlerDone := make(chan struct{})
context.AfterFunc(ctx, func() {
select {
case <-handlerDone:
default:
log.Debugf("HTTP request context canceled mid-flight: %v %v (reqID=%s, after %v, cause: %v)",
r.Method, r.URL.Path, reqID, time.Since(reqStart), context.Cause(ctx))
}
})
h.ServeHTTP(w, r.WithContext(ctx))
close(handlerDone)
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {