Compare commits

..

7 Commits

Author SHA1 Message Date
bcmmbaga
feb8e90ae1 Evaluate all applied posture checks on source peers only
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 23:28:34 +03:00
bcmmbaga
076d6d8a87 Evaluate all applied posture checks once
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 22:12:47 +03:00
bcmmbaga
c8c25221bd Apply policy posture checks on peer
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 21:49:28 +03:00
Pascal Fischer
fbce8bb511 [management] remove ids from policy creation api (#2997) 2024-12-27 14:13:36 +01:00
Bethuel Mmbaga
445b626dc8 [management] Add missing group usage checks for network resources and routes access control (#3117)
* Prevent deletion of groups linked to routes access control groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Prevent deletion of groups linked to network resource

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 14:39:34 +03:00
Viktor Liu
b3c87cb5d1 [client] Fix inbound tracking in userspace firewall (#3111)
* Don't create state for inbound SYN

* Allow final ack in some cases

* Relax state machine test a little
2024-12-26 00:51:27 +01:00
Viktor Liu
0dbaddc7be [client] Don't fail debug if log file is console (#3103) 2024-12-24 15:05:23 +01:00
27 changed files with 260 additions and 718 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testIn
ignore_words_list: erro,clienta,hastable,iif,groupd
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@@ -215,11 +215,6 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -100,9 +100,6 @@ type Manager interface {
// Flush the changes to firewall controller
Flush() error
// CollectStats returns the statistics of the firewall manager
CollectStats() []*FlowStats
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -1,107 +0,0 @@
package manager
import (
"encoding/json"
"net"
"slices"
"strconv"
"sync/atomic"
"time"
)
const (
DirectionInbound Direction = 0
DirectionOutbound Direction = 1
)
type Direction uint8
func (d Direction) String() string {
switch d {
case DirectionInbound:
return "inbound"
case DirectionOutbound:
return "outbound"
default:
return "unknown"
}
}
// FlowStats tracks statistics for an individual connection
type FlowStats struct {
StartTime time.Time
LastSeen time.Time
BytesIn atomic.Uint64
BytesOut atomic.Uint64
PacketsIn atomic.Uint64
PacketsOut atomic.Uint64
Protocol uint8
Direction Direction
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
}
func (f *FlowStats) Clone() *FlowStats {
flowCopy := FlowStats{
StartTime: f.StartTime,
LastSeen: f.LastSeen,
Protocol: f.Protocol,
Direction: f.Direction,
SourceIP: slices.Clone(f.SourceIP),
DestIP: slices.Clone(f.DestIP),
SourcePort: f.SourcePort,
DestPort: f.DestPort,
}
flowCopy.BytesIn.Store(f.BytesIn.Load())
flowCopy.BytesOut.Store(f.BytesOut.Load())
flowCopy.PacketsIn.Store(f.PacketsIn.Load())
flowCopy.PacketsOut.Store(f.PacketsOut.Load())
return &flowCopy
}
// MarshalJSON implements json.Marshaler interface
func (f *FlowStats) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
StartTime time.Time
LastSeen time.Time
BytesIn uint64
BytesOut uint64
PacketsIn uint64
PacketsOut uint64
Protocol Protocol
Direction string
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
}{
StartTime: f.StartTime,
LastSeen: f.LastSeen,
BytesIn: f.BytesIn.Load(),
BytesOut: f.BytesOut.Load(),
PacketsIn: f.PacketsIn.Load(),
PacketsOut: f.PacketsOut.Load(),
Protocol: protoFromInt(f.Protocol),
Direction: f.Direction.String(),
SourceIP: f.SourceIP,
DestIP: f.DestIP,
SourcePort: f.SourcePort,
DestPort: f.DestPort,
})
}
func protoFromInt(p uint8) Protocol {
switch p {
case 6:
return ProtocolTCP
case 17:
return ProtocolUDP
case 1:
return ProtocolICMP
default:
return Protocol(strconv.Itoa(int(p)))
}
}

View File

@@ -323,11 +323,6 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush()
}
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -17,17 +17,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if m.nativeFirewall != nil {

View File

@@ -29,17 +29,17 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if !isWindowsFirewallReachable() {

View File

@@ -10,7 +10,6 @@ import (
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP
DestIP net.IP
SourcePort uint16

View File

@@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) {
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
// Generate different IPs
@@ -79,17 +79,17 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, nil)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, nil)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
// Generate different IPs
@@ -104,11 +104,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, nil)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), nil)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
}
}
})

View File

@@ -6,8 +6,6 @@ import (
"time"
"github.com/google/gopacket/layers"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
@@ -41,11 +39,10 @@ type ICMPTracker struct {
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
stats *Stats
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, stats *Stats) *ICMPTracker {
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
@@ -56,7 +53,6 @@ func NewICMPTracker(timeout time.Duration, stats *Stats) *ICMPTracker {
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
@@ -64,7 +60,7 @@ func NewICMPTracker(timeout time.Duration, stats *Stats) *ICMPTracker {
}
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, packetData []byte) {
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
@@ -87,22 +83,14 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(1, srcIP, dstIP, 0, 0, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
if t.stats != nil {
key := makeConnKey(srcIP, dstIP, 0, 0)
t.stats.TrackPacket(1, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8, packetData []byte) bool {
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
switch icmpType {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
@@ -127,11 +115,6 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
return false
}
if t.stats != nil {
key := makeConnKey(srcIP, dstIP, 0, 0)
t.stats.TrackPacket(1, false, uint64(len(packetData)), true, key)
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&

View File

@@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@@ -15,12 +15,12 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0, nil)
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
}
})
}

View File

@@ -1,172 +0,0 @@
package conntrack
import (
"net"
"slices"
"sync"
"sync/atomic"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
// Stats represents connection tracking statistics
type Stats struct {
TotalConnsCreated atomic.Uint64
TotalConnsTimedOut atomic.Uint64
TotalPacketsDropped atomic.Uint64
ActiveConns atomic.Int64
TCPConns atomic.Int64
UDPConns atomic.Int64
ICMPConns atomic.Int64
TCPStateStats struct {
SynReceived atomic.Uint64
Established atomic.Uint64
FinWait atomic.Uint64
TimeWait atomic.Uint64
InvalidStates atomic.Uint64
}
PacketStats struct {
TCPPackets atomic.Uint64
UDPPackets atomic.Uint64
ICMPPackets atomic.Uint64
}
flowMutex sync.RWMutex
flows map[ConnKey]*fw.FlowStats
}
// NewStats creates a new Stats instance
func NewStats() *Stats {
return &Stats{
flows: make(map[ConnKey]*fw.FlowStats),
}
}
// TrackNewConnection records a new connection
func (s *Stats) TrackNewConnection(proto uint8, srcIP net.IP, dstIP net.IP, srcPort, dstPort uint16, direction fw.Direction) {
s.TotalConnsCreated.Add(1)
s.ActiveConns.Add(1)
switch proto {
case 6: // TCP
s.TCPConns.Add(1)
case 17: // UDP
s.UDPConns.Add(1)
case 1: // ICMP
s.ICMPConns.Add(1)
}
flow := &fw.FlowStats{
StartTime: time.Now(),
LastSeen: time.Now(),
Protocol: proto,
Direction: direction,
SourceIP: slices.Clone(srcIP),
DestIP: slices.Clone(dstIP),
SourcePort: srcPort,
DestPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
s.flowMutex.Lock()
s.flows[key] = flow
s.flowMutex.Unlock()
}
// TrackConnectionClosed records a connection closure
func (s *Stats) TrackConnectionClosed(proto uint8, timedOut bool, key ConnKey) {
s.ActiveConns.Add(-1)
if timedOut {
s.TotalConnsTimedOut.Add(1)
}
switch proto {
case 6: // TCP
s.TCPConns.Add(-1)
case 17: // UDP
s.UDPConns.Add(-1)
case 1: // ICMP
s.ICMPConns.Add(-1)
}
s.flowMutex.Lock()
delete(s.flows, key)
s.flowMutex.Unlock()
}
// TrackPacket records packet statistics
func (s *Stats) TrackPacket(proto uint8, dropped bool, bytes uint64, isInbound bool, key ConnKey) {
if dropped {
s.TotalPacketsDropped.Add(1)
return
}
switch proto {
case 6: // TCP
s.PacketStats.TCPPackets.Add(1)
case 17: // UDP
s.PacketStats.UDPPackets.Add(1)
case 1: // ICMP
s.PacketStats.ICMPPackets.Add(1)
}
s.flowMutex.RLock()
if flow, exists := s.flows[key]; exists {
if isInbound {
flow.BytesIn.Add(bytes)
flow.PacketsIn.Add(1)
} else {
flow.BytesOut.Add(bytes)
flow.PacketsOut.Add(1)
}
flow.LastSeen = time.Now()
}
s.flowMutex.RUnlock()
}
// TrackTCPState updates TCP state statistics
func (s *Stats) TrackTCPState(newState TCPState) {
switch newState {
case TCPStateSynReceived:
s.TCPStateStats.SynReceived.Add(1)
case TCPStateEstablished:
s.TCPStateStats.Established.Add(1)
case TCPStateFinWait1, TCPStateFinWait2:
s.TCPStateStats.FinWait.Add(1)
case TCPStateTimeWait:
s.TCPStateStats.TimeWait.Add(1)
default:
s.TCPStateStats.InvalidStates.Add(1)
}
}
// GetFlowSnapshot returns a copy of current flow statistics if enabled
func (s *Stats) GetFlowSnapshot() []*fw.FlowStats {
s.flowMutex.RLock()
defer s.flowMutex.RUnlock()
snapshot := make([]*fw.FlowStats, 0, len(s.flows))
for _, flow := range s.flows {
snapshot = append(snapshot, flow.Clone())
}
return snapshot
}
// CleanupFlows removes flow entries older than the specified duration if enabled
func (s *Stats) CleanupFlows(maxAge time.Duration) {
threshold := time.Now().Add(-maxAge)
s.flowMutex.Lock()
defer s.flowMutex.Unlock()
for key, flow := range s.flows {
if flow.LastSeen.Before(threshold) {
delete(s.flows, key)
}
}
}

View File

@@ -6,8 +6,6 @@ import (
"net"
"sync"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
@@ -64,6 +62,7 @@ type TCPConnKey struct {
type TCPConnTrack struct {
BaseConnTrack
State TCPState
sync.RWMutex
}
// TCPTracker manages TCP connection states
@@ -74,18 +73,16 @@ type TCPTracker struct {
done chan struct{}
timeout time.Duration
ipPool *PreallocatedIPs
stats *Stats
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, stats *Stats) *TCPTracker {
func NewTCPTracker(timeout time.Duration) *TCPTracker {
tracker := &TCPTracker{
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
@@ -93,13 +90,15 @@ func NewTCPTracker(timeout time.Duration, stats *Stats) *TCPTracker {
}
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
@@ -117,69 +116,24 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(6, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
// Lock individual connection for state update
conn.Lock()
oldState := conn.State
t.updateState(conn, flags, true)
if oldState != conn.State && t.stats != nil {
t.stats.TrackTCPState(conn.State)
}
conn.Unlock()
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) bool {
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
// Handle new SYN packets
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.Lock()
if _, exists := t.connections[key]; !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, dstIP)
copyIP(dstIPCopy, srcIP)
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: dstPort,
DestPort: srcPort,
},
State: TCPStateSynReceived,
}
conn.lastSeen.Store(time.Now().UnixNano())
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock()
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
}
return true
}
// Look up existing connection
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
@@ -188,15 +142,10 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
}
// Handle RST packets
if flags&TCPRst != 0 {
conn.Lock()
isEstablished := conn.IsEstablished()
if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
@@ -206,7 +155,6 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
// Update state
conn.Lock()
t.updateState(conn, flags, false)
conn.UpdateLastSeen()
@@ -329,6 +277,11 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck:
return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
}
return false
}

View File

@@ -9,7 +9,7 @@ import (
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, nil)
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
require.True(t, valid, "Data should be allowed after handshake")
},
},
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
},
{
@@ -122,14 +122,11 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection")
// Verify connection is closed
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
t.Helper()
require.False(t, valid, "Data should be blocked after RST")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
},
},
{
@@ -141,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "Final ACKs should be allowed")
},
},
@@ -157,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, nil)
tracker = NewTCPTracker(DefaultTCPTimeout)
tt.test(t)
})
}
@@ -165,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@@ -184,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: true,
desc: "Should accept RST for established connection",
@@ -198,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: false,
desc: "Should reject RST without connection",
@@ -226,17 +223,17 @@ func TestRSTHandling(t *testing.T) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@@ -244,12 +241,12 @@ func BenchmarkTCPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@@ -257,17 +254,17 @@ func BenchmarkTCPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, nil)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@@ -277,9 +274,9 @@ func BenchmarkTCPTracker(b *testing.B) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, nil)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
}
i++
}
@@ -290,14 +287,14 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
// Wait for connections to expire

View File

@@ -4,8 +4,6 @@ import (
"net"
"sync"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
@@ -28,11 +26,10 @@ type UDPTracker struct {
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
stats *Stats
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, stats *Stats) *UDPTracker {
func NewUDPTracker(timeout time.Duration) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
@@ -43,7 +40,6 @@ func NewUDPTracker(timeout time.Duration, stats *Stats) *UDPTracker {
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
@@ -51,7 +47,7 @@ func NewUDPTracker(timeout time.Duration, stats *Stats) *UDPTracker {
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) {
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
@@ -74,21 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(17, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
if t.stats != nil {
t.stats.TrackPacket(17, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) bool {
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
@@ -103,10 +92,6 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
if t.stats != nil {
t.stats.TrackPacket(17, false, uint64(len(packetData)), true, key)
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&

View File

@@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, nil)
tracker := NewUDPTracker(tt.timeout)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
@@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
@@ -48,7 +48,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
// Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
@@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, nil)
tracker := NewUDPTracker(1 * time.Second)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
@@ -72,7 +72,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tests := []struct {
name string
@@ -144,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, nil)
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
assert.Equal(t, tt.want, got)
})
}
@@ -189,7 +189,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, nil)
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
}
// Verify initial connections
@@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@@ -219,12 +219,12 @@ func BenchmarkUDPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@@ -232,12 +232,12 @@ func BenchmarkUDPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, nil)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), nil)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
}
})
}

View File

@@ -22,10 +22,7 @@ import (
const layerTypeAll = 0
const (
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
EnvEnableStats = "NB_ENABLE_CONNTRACK_STATS"
)
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
@@ -55,9 +52,6 @@ type Manager struct {
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
statsEnabled bool
stats *conntrack.Stats
}
// decoder for packages
@@ -90,7 +84,6 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
func create(iface IFaceMapper) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
enableStats, _ := strconv.ParseBool(os.Getenv(EnvEnableStats))
m := &Manager{
decoders: sync.Pool{
@@ -110,21 +103,15 @@ func create(iface IFaceMapper) (*Manager, error) {
incomingRules: make(map[string]RuleSet),
wgIface: iface,
stateful: !disableConntrack,
statsEnabled: enableStats,
}
if enableStats {
m.stats = conntrack.NewStats()
log.Info("connection tracking statistics enabled")
}
// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.stats)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.stats)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.stats)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if err := iface.SetFilter(m); err != nil {
@@ -317,10 +304,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.udpTracker.TrackOutbound(srcIP, dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
packetData)
m.trackUDPOutbound(d, srcIP, dstIP)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
@@ -329,16 +313,9 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
if m.stateful {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.tcpTracker.TrackOutbound(srcIP, dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
packetData)
m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP,
d.icmp4.Id,
d.icmp4.Seq,
packetData)
m.trackICMPOutbound(d, srcIP, dstIP)
}
}
@@ -356,6 +333,17 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
}
}
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8
if tcp.SYN {
@@ -379,6 +367,15 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
m.udpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
}
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
@@ -392,6 +389,17 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
)
}
}
// dropFilter implements filtering logic for incoming packets
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
m.mutex.RLock()
@@ -415,7 +423,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
}
// Check connection state only if enabled
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, packetData) {
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false
}
@@ -439,7 +447,7 @@ func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
@@ -448,7 +456,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, pack
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
packetData,
)
case layers.LayerTypeUDP:
@@ -457,7 +464,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, pack
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
packetData,
)
case layers.LayerTypeICMPv4:
@@ -467,7 +473,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, pack
d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(),
packetData,
)
// TODO: ICMPv6
@@ -607,11 +612,3 @@ func (m *Manager) RemovePacketHook(hookID string) error {
}
return fmt.Errorf("hook with given id not found")
}
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
if m.stats == nil {
return nil
}
return m.stats.GetFlowSnapshot()
}

View File

@@ -5,7 +5,6 @@ import (
"math/rand"
"net"
"os"
"strconv"
"strings"
"testing"
@@ -966,114 +965,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}
}
func BenchmarkFirewallStats(b *testing.B) {
scenarios := []struct {
name string
stats bool
longLived bool
conns int
}{
{"nostats_short_100", false, false, 100},
{"stats_short_100", true, false, 100},
{"nostats_long_100", false, true, 100},
{"stats_long_100", true, true, 100},
{"nostats_short_1000", false, false, 1000},
{"stats_short_1000", true, false, 1000},
{"nostats_long_1000", false, true, 1000},
{"stats_long_1000", true, true, 1000},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
b.Setenv(EnvEnableStats, strconv.FormatBool(sc.stats))
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Generate test IPs
srcIPs := make([]net.IP, sc.conns)
dstIPs := make([]net.IP, sc.conns)
for i := 0; i < sc.conns; i++ {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
// Pre-generate packets
inPackets := make([][]byte, sc.conns)
outPackets := make([][]byte, sc.conns)
for i := 0; i < sc.conns; i++ {
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
if sc.longLived {
// Establish connection
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
synAck := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(syn)
manager.dropFilter(synAck, manager.incomingRules)
manager.processOutgoingHooks(ack)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
connIdx := i % sc.conns
if !sc.longLived {
// New connection each time
syn := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPSyn))
synAck := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
80, uint16(1024+connIdx), uint16(conntrack.TCPSyn|conntrack.TCPAck))
ack := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(syn)
manager.dropFilter(synAck, manager.incomingRules)
manager.processOutgoingHooks(ack)
}
// Data transfer
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
if !sc.longLived {
// Tear down
finClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPFin|conntrack.TCPAck))
ackServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
80, uint16(1024+connIdx), uint16(conntrack.TCPAck))
finServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
80, uint16(1024+connIdx), uint16(conntrack.TCPFin|conntrack.TCPAck))
ackClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(finClient)
manager.dropFilter(ackServer, manager.incomingRules)
manager.dropFilter(finServer, manager.incomingRules)
manager.processOutgoingHooks(ackClient)
}
}
})
}
}
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()

View File

@@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil)
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
@@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil)
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{

View File

@@ -23,7 +23,7 @@ import (
"google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/firewall"
firewallmanager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
@@ -158,7 +158,7 @@ type Engine struct {
statusRecorder *peer.Status
firewall firewallmanager.Manager
firewall manager.Manager
routeManager routemanager.Manager
acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
@@ -1576,14 +1576,6 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nm, nil
}
// GetFirewallStats returns the firewall stats
func (e *Engine) GetFirewallStats() []*firewallmanager.FlowStats {
if e.firewall != nil {
return e.firewall.CollectStats()
}
return nil
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
if !enabled {

View File

@@ -24,7 +24,6 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -45,7 +44,6 @@ iptables.txt: Anonymized iptables rules with packet counters, if --system-info f
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
firewall_stats.json: Anonymized firewall statistics of the NetBird client.
state.json: Anonymized client state dump containing netbird states.
@@ -183,33 +181,21 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
}
if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add routes to debug bundle: %v", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
}
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
}
s.addSystemInfo(req, anonymizer, archive)
}
if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
return fmt.Errorf("add network map: %w", err)
}
if err := s.addFirewallStats(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall stats to debug bundle: %v", err)
}
if err := s.addStateFile(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add state file to debug bundle: %v", err)
}
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
if s.logFile != "console" {
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
}
}
if err := archive.Close(); err != nil {
@@ -218,6 +204,20 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
return nil
}
func (s *Server) addSystemInfo(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add routes to debug bundle: %v", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
}
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
}
}
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if req.GetAnonymize() {
readmeReader := strings.NewReader(readmeContent)
@@ -358,32 +358,6 @@ func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonym
return nil
}
func (s *Server) addFirewallStats(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
if s.connectClient == nil || s.connectClient.Engine() == nil {
return nil
}
stats := s.connectClient.Engine().GetFirewallStats()
if stats == nil {
return nil
}
if req.GetAnonymize() {
anonymizeStatIPs(stats, anonymizer)
}
jsonBytes, err := json.MarshalIndent(stats, "", " ")
if err != nil {
return fmt.Errorf("marshal firewall stats: %w", err)
}
if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "firewall_stats.json"); err != nil {
return fmt.Errorf("add firewall stats to zip: %w", err)
}
return nil
}
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
path := statemanager.GetDefaultStatePath()
if path == "" {
@@ -968,18 +942,3 @@ func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
}
return v
}
func anonymizeStatIPs(stats []*firewall.FlowStats, anonymizer *anonymize.Anonymizer) {
for _, stat := range stats {
if stat.SourceIP != nil {
if ip, ok := netip.AddrFromSlice(stat.SourceIP); ok {
stat.SourceIP = anonymizer.AnonymizeIP(ip).AsSlice()
}
}
if stat.DestIP != nil {
if ip, ok := netip.AddrFromSlice(stat.DestIP); ok {
stat.DestIP = anonymizer.AnonymizeIP(ip).AsSlice()
}
}
}
}

View File

@@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if len(group.Resources) > 0 {
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
@@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
}
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
isLinked := slices.Contains(r.Groups, groupID) ||
slices.Contains(r.PeerGroups, groupID) ||
slices.Contains(r.AccessControlGroups, groupID)
if isLinked {
return true, r
}
}

View File

@@ -725,10 +725,6 @@ components:
PolicyRuleMinimum:
type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name:
description: Policy rule name identifier
type: string
@@ -790,6 +786,31 @@ components:
- end
PolicyRuleUpdate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources:
description: Policy rule source group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv797"
destinations:
description: Policy rule destination group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9h7v7m0"
required:
- sources
- destinations
PolicyRuleCreate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
@@ -817,6 +838,10 @@ components:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources:
description: Policy rule source group IDs
type: array
@@ -836,10 +861,6 @@ components:
PolicyMinimum:
type: object
properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name:
description: Policy name identifier
type: string
@@ -854,7 +875,6 @@ components:
example: true
required:
- name
- description
- enabled
PolicyUpdate:
allOf:
@@ -874,11 +894,33 @@ components:
$ref: '#/components/schemas/PolicyRuleUpdate'
required:
- rules
PolicyCreate:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
items:
type: string
example: "chacdk86lnnboviihd70"
rules:
description: Policy rule object for policy UI editor
type: array
items:
$ref: '#/components/schemas/PolicyRuleUpdate'
required:
- rules
Policy:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
@@ -2463,7 +2505,7 @@ paths:
content:
'application/json':
schema:
$ref: '#/components/schemas/PolicyUpdate'
$ref: '#/components/schemas/PolicyCreate'
responses:
'200':
description: A Policy object

View File

@@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct {
// Policy defines model for Policy.
type Policy struct {
// Description Policy friendly description
Description string `json:"description"`
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
@@ -897,16 +897,31 @@ type Policy struct {
SourcePostureChecks []string `json:"source_posture_checks"`
}
// PolicyMinimum defines model for PolicyMinimum.
type PolicyMinimum struct {
// PolicyCreate defines model for PolicyCreate.
type PolicyCreate struct {
// Description Policy friendly description
Description string `json:"description"`
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Id Policy ID
Id *string `json:"id,omitempty"`
// Name Policy name identifier
Name string `json:"name"`
// Rules Policy rule object for policy UI editor
Rules []PolicyRuleUpdate `json:"rules"`
// SourcePostureChecks Posture checks ID's applied to policy source groups
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
}
// PolicyMinimum defines model for PolicyMinimum.
type PolicyMinimum struct {
// Description Policy friendly description
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Name Policy name identifier
Name string `json:"name"`
@@ -970,9 +985,6 @@ type PolicyRuleMinimum struct {
// Enabled Policy rule status
Enabled bool `json:"enabled"`
// Id Policy rule ID
Id *string `json:"id,omitempty"`
// Name Policy rule name identifier
Name string `json:"name"`
@@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string
// PolicyUpdate defines model for PolicyUpdate.
type PolicyUpdate struct {
// Description Policy friendly description
Description string `json:"description"`
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Id Policy ID
Id *string `json:"id,omitempty"`
// Name Policy name identifier
Name string `json:"name"`
@@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest
type PostApiPoliciesJSONRequestBody = PolicyUpdate
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate

View File

@@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
return
}
description := ""
if req.Description != nil {
description = *req.Description
}
policy := &types.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
Description: description,
}
for _, rule := range req.Rules {
var ruleID string
if rule.Id != nil {
if rule.Id != nil && policyID != "" {
ruleID = *rule.Id
}
@@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
ap := &api.Policy{
Id: &policy.ID,
Name: policy.Name,
Description: policy.Description,
Description: &policy.Description,
Enabled: policy.Enabled,
SourcePostureChecks: policy.SourcePostureChecks,
}

View File

@@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
func TestPoliciesWritePolicy(t *testing.T) {
str := func(s string) *string { return &s }
emptyString := ""
tt := []struct {
name string
expectedStatus int
@@ -184,8 +185,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPolicy: &api.Policy{
Id: str("id-was-set"),
Name: "Default POSTed Policy",
Id: str("id-was-set"),
Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{
{
Id: str("id-was-set"),
@@ -232,8 +234,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPolicy: &api.Policy{
Id: str("id-existed"),
Name: "Default POSTed Policy",
Id: str("id-existed"),
Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{
{
Id: str("id-existed"),

View File

@@ -1319,6 +1319,18 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
}
}
if !addSourcePeers {
var peerPostureChecks []string
for _, policy := range resourcePolicies[resource.ID] {
peerPostureChecks = append(peerPostureChecks, policy.SourcePostureChecks...)
}
isValid := a.validatePostureChecksOnPeer(ctx, peerPostureChecks, peerID)
if !isValid {
continue
}
}
for _, policy := range resourcePolicies[resource.ID] {
for _, sourceGroup := range policy.SourceGroups() {
group := a.GetGroup(sourceGroup)