Compare commits

..

6 Commits

Author SHA1 Message Date
Pedro Costa
7b02e9c3a8 [management] base manager 2025-03-05 09:52:18 +00:00
hakansa
eee90fbbbf [client] UI Refactor Icon Paths (#3420)
[client] UI Refactor Icon Paths (#3420)
2025-03-05 09:47:29 +00:00
Viktor Liu
85aea0a030 [client] Close userspace firewall properly (#3426) 2025-03-05 09:47:27 +00:00
robertgro
e41acdb9ac [client] Add Netbird GitHub link to the client ui about sub menu (#3372) 2025-03-05 09:46:21 +00:00
Philippe Vaucher
54dc27abec [client Fix env var typo (#3415) 2025-03-05 09:46:21 +00:00
Bethuel Mmbaga
fdd7dc67c0 [management] Handle transaction error on peer deletion (#3387)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-03-05 09:46:21 +00:00
165 changed files with 1636 additions and 9283 deletions

View File

@@ -31,22 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version`
**Is any other VPN software installed?**
**NetBird status -dA output:**
If yes, which one?
If applicable, add the `netbird status -dA' command output.
**Debug output**
**Do you face any (non-mobile) client issues?**
To help us resolve the problem, please attach the following debug output
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots**
@@ -55,10 +47,3 @@ If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -258,7 +258,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
@@ -325,8 +325,8 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -392,7 +392,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
@@ -461,7 +461,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:

View File

@@ -134,11 +134,10 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
run := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
clientErr <- err
run <- err
}
}()
@@ -148,9 +147,13 @@ func (c *Client) Start(startCtx context.Context) error {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
case err := <-run:
if err != nil {
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
}
c.connect = client

View File

@@ -4,13 +4,12 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() device.WGAddress
IsUserspaceBind() bool
SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
}

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),

View File

@@ -4,20 +4,21 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
}

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),

View File

@@ -3,20 +3,21 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}

View File

@@ -4,11 +4,11 @@ package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -17,23 +17,26 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {

View File

@@ -3,7 +3,6 @@ package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
@@ -22,13 +21,13 @@ const (
firewallRuleName = "Netbird"
)
// Reset firewall to the default state
// Close closes the firewall manager
func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -45,8 +44,8 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {

View File

@@ -3,14 +3,14 @@ package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
Address() iface.WGAddress
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -2,6 +2,7 @@ package conntrack
import (
"fmt"
"net"
"net/netip"
"sync/atomic"
"time"
@@ -13,15 +14,13 @@ import (
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64
}
// these small methods will be inlined by the compiler
@@ -31,17 +30,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
@@ -64,3 +52,16 @@ type ConnKey struct {
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -2,7 +2,7 @@ package conntrack
import (
"context"
"net/netip"
"net"
"testing"
"github.com/sirupsen/logrus"
@@ -12,7 +12,7 @@ import (
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
@@ -21,22 +21,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
}
}
})
@@ -46,22 +46,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
}
}
})

View File

@@ -1,8 +1,8 @@
package conntrack
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
@@ -23,13 +23,14 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
ID uint16
SrcIP netip.Addr
DstIP netip.Addr
Sequence uint16
ID uint16
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.ID, i.Sequence)
}
// ICMPConnTrack represents an ICMP connection state
@@ -45,8 +46,8 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
flowLogger nftypes.FlowLogger
}
@@ -56,27 +57,21 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout = DefaultICMPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
done: make(chan struct{}),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
}
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
key := makeICMPKey(srcIP, dstIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -84,7 +79,6 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
@@ -93,21 +87,22 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
// TODO: icmp doesn't need to extend the timeout
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
if exists {
return
}
@@ -117,7 +112,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
t.sendStartEvent(direction, key, typ, code)
return
}
@@ -125,8 +120,8 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
},
ICMPType: typ,
ICMPCode: code,
@@ -138,20 +133,16 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -162,19 +153,16 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -188,58 +176,56 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
t.connections = nil
t.mutex.Unlock()
}
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
ICMPType: typ,
ICMPCode: code,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
})
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ICMPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
}
}

View File

@@ -1,7 +1,7 @@
package conntrack
import (
"net/netip"
"net"
"testing"
)
@@ -10,12 +10,12 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), 0)
}
})
@@ -23,17 +23,17 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
}
})
}

View File

@@ -3,8 +3,7 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"context"
"net/netip"
"net"
"sync"
"sync/atomic"
"time"
@@ -89,8 +88,6 @@ const (
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState
established atomic.Bool
tombstone atomic.Bool
@@ -123,7 +120,7 @@ type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
done chan struct{}
timeout time.Duration
flowLogger nftypes.FlowLogger
}
@@ -134,28 +131,21 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout = DefaultTCPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
done: make(chan struct{}),
timeout: timeout,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -164,10 +154,9 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.UpdateLastSeen()
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
@@ -175,36 +164,37 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
conn.established.Store(false)
conn.tombstone.Store(false)
@@ -215,17 +205,12 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -246,15 +231,15 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.sendEvent(nftypes.TypeEnd, key, conn)
return true
}
conn.Lock()
t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
@@ -264,8 +249,6 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen()
state := conn.State
defer func() {
if state != conn.State {
@@ -304,24 +287,17 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait2:
@@ -329,7 +305,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
case TCPStateClosing:
@@ -338,7 +314,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
case TCPStateCloseWait:
@@ -352,7 +328,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.SetTombstone()
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.sendEvent(nftypes.TypeEnd, key, conn)
t.logger.Trace("TCP connection %s closed gracefully", key)
}
}
@@ -399,14 +375,12 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false
}
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
func (t *TCPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -437,11 +411,11 @@ func (t *TCPTracker) cleanup() {
// Return IPs to pool
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
@@ -449,7 +423,8 @@ func (t *TCPTracker) cleanup() {
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
// Clean up all remaining IPs
t.mutex.Lock()
@@ -471,20 +446,15 @@ func isValidFlagCombination(flags uint8) bool {
return true
}
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
})
}

View File

@@ -1,7 +1,7 @@
package conntrack
import (
"net/netip"
"net"
"testing"
"time"
@@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
require.True(t, valid, "Data should be allowed after handshake")
},
},
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
},
{
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "Final ACKs should be allowed")
},
},
@@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: true,
desc: "Should accept RST for established connection",
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: false,
desc: "Should reject RST without connection",
@@ -208,12 +208,7 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST()
// Verify connection state is as expected
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
@@ -225,15 +220,15 @@ func TestRSTHandling(t *testing.T) {
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}
func BenchmarkTCPTracker(b *testing.B) {
@@ -241,12 +236,12 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
}
})
@@ -254,17 +249,17 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
}
})
@@ -272,16 +267,16 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
}
i++
}
@@ -296,10 +291,10 @@ func BenchmarkCleanup(b *testing.B) {
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
// Wait for connections to expire

View File

@@ -1,8 +1,7 @@
package conntrack
import (
"context"
"net/netip"
"net"
"sync"
"time"
@@ -22,8 +21,6 @@ const (
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
}
// UDPTracker manages UDP connection states
@@ -32,8 +29,8 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
flowLogger nftypes.FlowLogger
}
@@ -43,41 +40,34 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout = DefaultUDPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
done: make(chan struct{}),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
}
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -85,7 +75,6 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
@@ -93,21 +82,21 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
@@ -116,17 +105,12 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -137,20 +121,17 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
func (t *UDPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -164,16 +145,16 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("Removed UDP connection %s (timeout)", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
t.connections = nil
@@ -181,16 +162,11 @@ func (t *UDPTracker) Close() {
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
return conn, exists
}
@@ -200,20 +176,15 @@ func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
})
}

View File

@@ -1,7 +1,7 @@
package conntrack
import (
"context"
"net"
"net/netip"
"testing"
"time"
@@ -35,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.tickerCancel)
assert.NotNil(t, tracker.done)
})
}
}
@@ -49,15 +49,10 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
// Verify connection was tracked
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
@@ -71,18 +66,18 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
sleep time.Duration
@@ -99,7 +94,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
},
{
name: "invalid source IP",
srcIP: netip.MustParseAddr("192.168.1.4"),
srcIP: net.ParseIP("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
@@ -109,7 +104,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
@@ -149,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
assert.Equal(t, tt.want, got)
})
}
@@ -160,45 +155,42 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
tickerCancel: tickerCancel,
done: make(chan struct{}),
logger: logger,
flowLogger: flowLogger,
}
// Start cleanup routine
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
// Add some connections
connections := []struct {
srcIP netip.Addr
dstIP netip.Addr
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
}{
{
srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: netip.MustParseAddr("192.168.1.3"),
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.5"),
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
}
// Verify initial connections
@@ -223,12 +215,12 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
}
})
@@ -236,17 +228,17 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
}
})
}

View File

@@ -15,16 +15,13 @@ import (
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
flowID := uuid.New()
// Extract ICMP header to get type and code
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
@@ -36,6 +33,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
@@ -43,15 +42,30 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id, flowID)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
@@ -59,20 +73,21 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID, flowID uuid.UUID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return
return true
}
response := make([]byte, f.endpoint.mtu)
@@ -81,7 +96,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return
return true
}
ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -102,11 +117,13 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return
return true
}
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return true
}
// sendICMPEvent stores flow events for ICMP packets
@@ -117,11 +134,9 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
}

View File

@@ -22,14 +22,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
@@ -58,7 +51,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
@@ -74,7 +66,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id)
}()
// Create context for managing the proxy goroutines
@@ -106,27 +98,17 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
}
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
Protocol: 6,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}

View File

@@ -89,6 +89,21 @@ func (f *udpForwarder) Stop() {
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 17,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
@@ -125,6 +140,8 @@ func (f *udpForwarder) cleanup() {
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id)
}
}
}
@@ -148,19 +165,13 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
// TODO: Send ICMP error message
return
}
@@ -173,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
@@ -200,14 +212,13 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
@@ -227,7 +238,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id)
}()
errChan := make(chan error, 2)
@@ -253,30 +264,19 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
Protocol: 17, // UDP protocol number
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
func (c *udpPacketConn) updateLastSeen() {

View File

@@ -3,7 +3,6 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
@@ -32,9 +31,13 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
@@ -119,12 +122,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
if ipv4 := ip.To4(); ipv4 != nil {
return m.checkBitmapBit(ipv4)
}
return false

View File

@@ -2,91 +2,90 @@ package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr wgaddr.Address
testIP netip.Addr
setupAddr iface.WGAddress
testIP net.IP
expected bool
}{
{
name: "Localhost range",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
testIP: net.ParseIP("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
testIP: net.ParseIP("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
testIP: net.ParseIP("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
testIP: net.ParseIP("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
testIP: net.ParseIP("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
testIP: net.ParseIP("fe80::1"),
expected: false,
},
}
@@ -96,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
AddressFunc: func() iface.WGAddress {
return tt.setupAddr
},
}
@@ -175,7 +174,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
result := manager.IsLocalIP(net.ParseIP(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"net"
"net/netip"
"github.com/google/gopacket"
@@ -12,7 +13,7 @@ import (
type PeerRule struct {
id string
mgmtId []byte
ip netip.Addr
ip net.IP
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType

View File

@@ -2,7 +2,7 @@ package uspfilter
import (
"fmt"
"net/netip"
"net"
"time"
"github.com/google/gopacket"
@@ -53,8 +53,8 @@ type TraceResult struct {
}
type PacketTrace struct {
SourceIP netip.Addr
DestinationIP netip.Addr
SourceIP net.IP
DestinationIP net.IP
Protocol string
SourcePort uint16
DestinationPort uint16
@@ -72,8 +72,8 @@ type TCPState struct {
}
type PacketBuilder struct {
SrcIP netip.Addr
DstIP netip.Addr
SrcIP net.IP
DstIP net.IP
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
SrcIP: p.SrcIP,
DstIP: p.DstIP,
}
}
@@ -260,30 +260,28 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
if m.localipmanager.IsLocalIP(dstIP) {
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if !m.handleRouting(trace) {
return trace
}
if m.nativeRouter.Load() {
if m.nativeRouter {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
@@ -311,46 +309,39 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true
}
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
strRuleId := "implicit"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true
}
trace.AddResult(StagePeerACL, msg, true)
trace.AddResult(StagePeerACL, msg, !blocked)
// Handle netstack mode
if m.netstack {
switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
}
// In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
return true
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled.Load() {
if !m.routingEnabled {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
@@ -366,14 +357,14 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id)
if id == nil {
strId = "<no id>"
strId = "implicit"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
@@ -382,7 +373,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil {
if allowed && m.forwarder != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
@@ -401,7 +392,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0)
dropped := m.processOutgoingHooks(packetData)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -1,440 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

View File

@@ -10,7 +10,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -67,9 +66,9 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager
type Manager struct {
// outgoingRules is used for hooks only
outgoingRules map[netip.Addr]RuleSet
outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet
incomingRules map[string]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool
@@ -81,9 +80,9 @@ type Manager struct {
// indicates whether server routes are disabled
disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves
routingEnabled atomic.Bool
routingEnabled bool
// indicates whether we leave forwarding and filtering to the native firewall
nativeRouter atomic.Bool
nativeRouter bool
// indicates whether we track outbound connections
stateful bool
// indicates whether wireguards runs in netstack mode
@@ -96,7 +95,7 @@ type Manager struct {
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
forwarder atomic.Pointer[forwarder.Forwarder]
forwarder *forwarder.Forwarder
logger *nblog.Logger
flowLogger nftypes.FlowLogger
}
@@ -169,18 +168,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
}
m.routingEnabled.Store(false)
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
@@ -212,7 +211,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder.Load() == nil {
if m.forwarder == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
@@ -256,20 +255,20 @@ func (m *Manager) determineRouting() error {
switch {
case disableUspRouting:
m.routingEnabled.Store(false)
m.nativeRouter.Store(false)
m.routingEnabled = false
m.nativeRouter = false
log.Info("userspace routing is disabled")
case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
m.routingEnabled = true
m.nativeRouter = true
log.Info("server routes are disabled")
case forceUserspaceRouter:
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing is forced")
@@ -277,19 +276,19 @@ func (m *Manager) determineRouting() error {
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
m.routingEnabled = true
m.nativeRouter = true
log.Info("native routing is enabled")
default:
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing enabled by default")
}
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
if m.routingEnabled && !m.nativeRouter {
return m.initForwarder()
}
@@ -298,24 +297,24 @@ func (m *Manager) determineRouting() error {
// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error {
if m.forwarder.Load() != nil {
if m.forwarder != nil {
return nil
}
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice()
if intf == nil {
m.routingEnabled.Store(false)
m.routingEnabled = false
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil {
m.routingEnabled.Store(false)
m.routingEnabled = false
return fmt.Errorf("create forwarder: %w", err)
}
m.forwarder.Store(forwarder)
m.forwarder = forwarder
log.Debug("forwarder initialized")
@@ -331,7 +330,7 @@ func (m *Manager) IsServerRouteSupported() bool {
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
}
@@ -342,7 +341,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair)
}
return nil
@@ -361,23 +360,17 @@ func (m *Manager) AddPeerFiltering(
action firewall.Action,
_ string,
) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{
id: uuid.New().String(),
mgmtId: id,
ip: i,
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
}
if i.Is4() {
if ipNormalized := ip.To4(); ipNormalized != nil {
r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
@@ -402,10 +395,10 @@ func (m *Manager) AddPeerFiltering(
}
m.mutex.Lock()
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(RuleSet)
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip][r.id] = r
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
@@ -419,10 +412,13 @@ func (m *Manager) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
@@ -436,16 +432,14 @@ func (m *Manager) AddRouteFiltering(
action: action,
}
m.mutex.Lock()
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
}
@@ -474,10 +468,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip], r.id)
delete(m.incomingRules[r.ip.String()], r.id)
return nil
}
@@ -510,13 +504,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.processOutgoingHooks(packetData)
}
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData)
}
// UpdateLocalIPs updates the list of local IPs
@@ -524,7 +518,10 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -537,34 +534,31 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
}
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
if srcIP == nil {
return false
}
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
// Track all protocols if stateful mode is enabled
if m.stateful {
m.trackOutbound(d, srcIP, dstIP)
}
if m.stateful {
m.trackOutbound(d, srcIP, dstIP, size)
// Process UDP hooks even if stateful mode is disabled
if d.decoded[1] == layers.LayerTypeUDP {
return m.checkUDPHooks(d, dstIP, packetData)
}
return false
}
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
return d.ip4.SrcIP, d.ip4.DstIP
case layers.LayerTypeIPv6:
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
return d.ip6.SrcIP, d.ip6.DstIP
default:
return netip.Addr{}, netip.Addr{}
return nil, nil
}
}
@@ -591,70 +585,51 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode)
}
}
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode)
}
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.udpHook(packetData)
}
}
}
}
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte, size int) bool {
func (m *Manager) dropFilter(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -663,19 +638,19 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
}
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
if srcIP == nil {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false
}
if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
@@ -683,28 +658,27 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
// handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
if blocked {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
RuleID: ruleId,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: srcAddr,
DestIP: dstAddr,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
})
return true
}
@@ -715,7 +689,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
}
// track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP, ruleID, size)
m.trackInbound(d, srcIP, dstIP)
return false
}
@@ -726,12 +700,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
return false
}
if m.forwarder.Load() == nil {
if m.forwarder == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true
}
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
}
@@ -741,34 +715,37 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
if !m.routingEnabled {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
// Pass to native stack if native router is enabled or forced
if m.nativeRouter.Load() {
if m.nativeRouter {
return false
}
proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
id, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
RuleID: id,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: srcAddr,
DestIP: dstAddr,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
@@ -777,7 +754,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
}
// Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err)
}
@@ -822,7 +799,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
@@ -831,7 +808,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
size,
)
case layers.LayerTypeUDP:
@@ -840,7 +816,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
size,
)
case layers.LayerTypeICMPv4:
@@ -848,8 +823,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(),
size,
)
// TODO: ICMPv6
@@ -869,22 +844,20 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) {
if m.isSpecialICMP(d) {
return nil, false
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
return mgmtId, filter
}
@@ -909,10 +882,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false
}
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
if rule.matchByIP && !ip.Equal(rule.ip) {
continue
}
@@ -946,13 +919,16 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return nil, false, false
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
// routeACLsPass returns treu if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
return rule.mgmtId, rule.action == firewall.ActionAccept
}
}
@@ -996,7 +972,9 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
func (m *Manager) AddUDPPacketHook(
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
@@ -1006,22 +984,23 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu
udpHook: hook,
}
if ip.Is4() {
if ip.To4() != nil {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
m.incomingRules[r.ip.String()][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
m.outgoingRules[r.ip.String()][r.id] = r
}
m.mutex.Unlock()
return r.id
@@ -1069,21 +1048,20 @@ func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
fwder := m.forwarder.Load()
if fwder == nil {
if m.forwarder == nil {
return nil
}
m.routingEnabled.Store(false)
m.nativeRouter.Store(false)
m.routingEnabled = false
m.nativeRouter = false
// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nil
}
fwder.Stop()
m.forwarder.Store(nil)
m.forwarder.Stop()
m.forwarder = nil
log.Debug("forwarder stopped")

View File

@@ -193,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -230,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
// Test packet
@@ -238,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut, 0)
manager.processOutgoingHooks(testOut)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, 0)
manager.dropFilter(testIn)
}
})
}
@@ -278,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -477,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -624,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
// Prepare test packets simulating bidirectional traffic
@@ -655,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx])
}
})
}
@@ -761,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
// Data transfer
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
// Connection teardown
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
}
})
}
@@ -826,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
// Pre-generate test packets
@@ -856,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx])
}
})
})
@@ -950,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
}
})
})
@@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}

View File

@@ -12,9 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestPeerACLFiltering(t *testing.T) {
@@ -26,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: localIP,
Network: wgNet,
}
@@ -192,7 +192,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.DropIncoming(packet)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.DropIncoming(packet)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: localIP,
Network: wgNet,
}
@@ -306,8 +306,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load())
require.True(tb, manager.routingEnabled)
require.False(tb, manager.nativeRouter)
tb.Cleanup(func() {
require.NoError(tb, manager.Close(nil))
@@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule))
})
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder
@@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
})
for i, p := range tc.packets {
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
srcIP := net.ParseIP(p.srcIP)
dstIP := net.ParseIP(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)

View File

@@ -18,17 +18,17 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
@@ -54,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() wgaddr.Address {
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return wgaddr.Address{}
return iface.WGAddress{}
}
return i.AddressFunc()
}
@@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
ip := netip.MustParseAddr("192.168.1.1")
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; ok {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
ip net.IP
dPort uint16
hook func([]byte) bool
expectedID string
@@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
ip: net.IPv4(10, 168, 0, 1),
dPort: 8000,
hook: func([]byte) bool { return true },
},
@@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
ip: net.IPv6loopback,
dPort: 9000,
hook: func([]byte) bool { return false },
},
@@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule
if tt.in {
if len(manager.incomingRules[tt.ip]) != 1 {
if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
} else {
@@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
}
if tt.ip.Compare(addedRule.ip) != 0 {
if !tt.ip.Equal(addedRule.ip) {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
@@ -269,8 +269,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), 0) {
if m.dropFilter(buf.Bytes()) {
t.Errorf("expected packet to be accepted")
return
}
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
netip.MustParseAddr("100.10.0.100"),
net.ParseIP("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes(), 0)
result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result)
require.True(t, hookCalled)
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result)
}
@@ -569,11 +569,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
@@ -636,12 +636,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
drop = manager.dropFilter(inboundBuf.Bytes())
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), 0)
drop = manager.dropFilter(testBuf.Bytes())
require.True(t, drop, tc.description)
})
}

View File

@@ -5,6 +5,7 @@ import (
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
@@ -13,8 +14,6 @@ import (
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type RecvMessage struct {
@@ -53,10 +52,9 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -66,7 +64,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -111,17 +108,35 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeIP)
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -146,10 +161,9 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -261,6 +275,21 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}

View File

@@ -17,8 +17,6 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
@@ -43,7 +41,6 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -67,7 +64,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
// embed UDPMux
@@ -122,7 +118,6 @@ type udpConn struct {
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -164,11 +159,6 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {

View File

@@ -9,14 +9,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice

View File

@@ -1,29 +1,29 @@
package wgaddr
package device
import (
"fmt"
"net"
)
// Address WireGuard parsed address
type Address struct {
// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) {
func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return Address{}, err
return WGAddress{}, err
}
return Address{
return WGAddress{
IP: ip,
Network: network,
}, nil
}
func (addr Address) String() string {
func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@@ -13,12 +13,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -32,7 +31,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -94,7 +93,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
@@ -124,7 +123,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name
}
func (t *WGTunDevice) WgAddress() wgaddr.Address {
func (t *WGTunDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -13,12 +13,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -30,7 +29,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -86,7 +85,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -107,7 +106,7 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -2,7 +2,6 @@ package device
import (
"net"
"net/netip"
"sync"
"golang.zx2c4.com/wireguard/tun"
@@ -11,16 +10,16 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
DropOutgoing(packetData []byte) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
DropIncoming(packetData []byte) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
@@ -58,7 +57,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -82,7 +81,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:], len(buf)) {
if !filter.DropIncoming(buf[offset:]) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -14,12 +14,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
iceBind *bind.ICEBind
@@ -31,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -121,11 +120,11 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}

View File

@@ -14,13 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
)
type TunKernelDevice struct {
name string
address wgaddr.Address
address WGAddress
wgPort int
key string
mtu int
@@ -35,7 +34,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
@@ -100,10 +99,9 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
@@ -114,7 +112,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -147,7 +145,7 @@ func (t *TunKernelDevice) Close() error {
return closErr
}
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -13,13 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunNetstackDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -35,7 +34,7 @@ type TunNetstackDevice struct {
net *netstack.Net
}
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -98,7 +97,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
@@ -117,7 +116,7 @@ func (t *TunNetstackDevice) Close() error {
return nil
}
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -12,12 +12,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type USPDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -29,7 +28,7 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
return &USPDevice{
@@ -94,7 +93,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -114,7 +113,7 @@ func (t *USPDevice) Close() error {
return nil
}
func (t *USPDevice) WgAddress() wgaddr.Address {
func (t *USPDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -13,14 +13,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -33,7 +32,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -119,7 +118,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -140,7 +139,7 @@ func (t *TunDevice) Close() error {
}
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -57,7 +56,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address WGAddress) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)

View File

@@ -8,8 +8,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -92,7 +90,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address WGAddress) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {

View File

@@ -7,14 +7,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice

View File

@@ -19,7 +19,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -29,6 +28,8 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -71,7 +72,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
func (w *WGIface) Address() wgaddr.Address {
func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress()
}
@@ -102,7 +103,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := wgaddr.ParseWGAddress(newAddr)
addr, err := device.ParseWGAddress(newAddr)
if err != nil {
return err
}

View File

@@ -3,18 +3,17 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,

View File

@@ -6,18 +6,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -5,18 +5,17 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,13 +8,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
@@ -22,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -35,7 +34,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,17 +4,16 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -6,7 +6,6 @@ package mocks
import (
net "net"
"net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -36,7 +35,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
@@ -50,31 +49,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// RemovePacketHook mocks base method.

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
log "github.com/sirupsen/logrus"
@@ -17,13 +16,13 @@ import (
type ProxyBind struct {
Bind *bind.ICEBind
fakeNetIP *netip.AddrPort
wgBindEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
wgAddr *net.UDPAddr
wgEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
@@ -34,24 +33,20 @@ type ProxyBind struct {
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr)
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
if err != nil {
return err
}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.wgAddr = addr
p.wgEndpoint = addrToEndpoint(addr)
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return nil
return err
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
return p.wgAddr
}
func (p *ProxyBind) Work() {
@@ -59,8 +54,6 @@ func (p *ProxyBind) Work() {
return
}
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
@@ -100,7 +93,7 @@ func (p *ProxyBind) close() error {
p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
p.Bind.RemoveEndpoint(p.wgAddr)
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
@@ -133,7 +126,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}
msg := bind.RecvMessage{
Endpoint: p.wgBindEndpoint,
Endpoint: p.wgEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
@@ -141,19 +134,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}

View File

@@ -22,8 +22,6 @@
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
Unicode True
######################################################################
@@ -70,9 +68,6 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH
@@ -85,36 +80,8 @@ Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_LANGUAGE "English"
; Variables for autostart option
Var AutostartCheckbox
Var AutostartEnabled
######################################################################
; Function to create the autostart options page
Function AutostartPage
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show
FunctionEnd
; Function to handle leaving the autostart page
Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
Function GetAppFromCommand
Exch $1
Push $2
@@ -196,16 +163,6 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR"
@@ -229,10 +186,7 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
# wait the service uninstall take unblock the executable
Sleep 3000

View File

@@ -9,13 +9,13 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
@@ -49,7 +49,7 @@ func TestDefaultManager(t *testing.T) {
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
@@ -343,7 +343,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()

View File

@@ -10,8 +10,8 @@ import (
gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// MockIFaceMapper is a mock of IFaceMapper interface.
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
}
// Address mocks base method.
func (m *MockIFaceMapper) Address() wgaddr.Address {
func (m *MockIFaceMapper) Address() iface.WGAddress {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address")
ret0, _ := ret[0].(wgaddr.Address)
ret0, _ := ret[0].(iface.WGAddress)
return ret0
}

View File

@@ -61,7 +61,7 @@ func NewConnectClient(
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}) error {
func (c *ConnectClient) Run(runningChan chan error) error {
return c.run(MobileDependency{}, runningChan)
}
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -159,9 +159,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.ctx.Err() != nil {
if c.isContextCancelled() {
return nil
}
@@ -281,11 +282,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
select {
case runningChan <- struct{}{}:
default:
}
if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
runningChanOpen = false
}
<-engineCtx.Done()
@@ -379,6 +379,15 @@ func (c *ConnectClient) Stop() error {
return nil
}
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored

View File

@@ -22,7 +22,6 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -31,7 +30,7 @@ import (
"github.com/netbirdio/netbird/formatter"
)
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
type mocWGIface struct {
filter device.PacketFilter
@@ -41,9 +40,9 @@ func (w *mocWGIface) Name() string {
panic("implement me")
}
func (w *mocWGIface) Address() wgaddr.Address {
func (w *mocWGIface) Address() iface.WGAddress {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{
return iface.WGAddress{
IP: ip,
Network: network,
}
@@ -459,7 +458,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)

View File

@@ -2,7 +2,7 @@ package dns
import (
"fmt"
"net/netip"
"net"
"sync"
"github.com/google/gopacket"
@@ -117,10 +117,5 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true
}
ip, err := netip.ParseAddr(s.runtimeIP)
if err != nil {
return "", fmt.Errorf("parse runtime ip: %w", err)
}
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}

View File

@@ -5,15 +5,15 @@ package dns
import (
"net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool
GetFilter() device.PacketFilter

View File

@@ -1,15 +1,15 @@
package dns
import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice

View File

@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
// start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:])
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
@@ -1641,19 +1641,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.ctx.Err() != nil {
return
}
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client context, engine will be recreated")
log.Infof("cancelling client, engine will be recreated")
e.clientCancel()
}
@@ -1665,17 +1662,34 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New()
go func() {
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
return
}
log.Errorf("network monitor error: %v", err)
return
}
var mu sync.Mutex
var debounceTimer *time.Timer
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
}
}()
}

View File

@@ -31,7 +31,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -76,7 +75,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
@@ -115,7 +114,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() wgaddr.Address {
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
@@ -365,8 +364,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error {
return nil
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -21,7 +20,7 @@ type wgIfaceBase interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() wgaddr.Address
Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error

View File

@@ -11,7 +11,6 @@ import (
"github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
)
type rcvChan chan *types.EventFields
@@ -22,17 +21,15 @@ type Logger struct {
enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc
statusRecorder *peer.Status
Store types.Store
}
func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
func New(ctx context.Context) *Logger {
ctx, cancel := context.WithCancel(ctx)
return &Logger{
ctx: ctx,
cancel: cancel,
statusRecorder: statusRecorder,
Store: store.NewMemoryStore(),
ctx: ctx,
cancel: cancel,
Store: store.NewMemoryStore(),
}
}
@@ -61,14 +58,13 @@ func (l *Logger) startReceiver() {
if l.enabled.Load() {
return
}
l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx)
l.cancelReceiver = cancel
l.mux.Unlock()
c := make(rcvChan, 100)
l.rcvChan.Store(&c)
l.rcvChan.Swap(&c)
l.enabled.Store(true)
for {
@@ -77,15 +73,12 @@ func (l *Logger) startReceiver() {
log.Info("flow Memory store receiver stopped")
return
case eventFields := <-c:
id := uuid.New()
id := uuid.NewString()
event := types.Event{
ID: id,
EventFields: *eventFields,
Timestamp: time.Now(),
}
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
event.SourceResourceID = []byte(srcResId)
event.DestResourceID = []byte(dstResId)
l.Store.StoreEvent(&event)
}
}
@@ -107,7 +100,6 @@ func (l *Logger) stop() {
l.cancelReceiver()
l.cancelReceiver = nil
}
l.rcvChan.Store(nil)
l.mux.Unlock()
}
@@ -115,7 +107,7 @@ func (l *Logger) GetEvents() []*types.Event {
return l.Store.GetEvents()
}
func (l *Logger) DeleteEvents(ids []uuid.UUID) {
func (l *Logger) DeleteEvents(ids []string) {
l.Store.DeleteEvents(ids)
}

View File

@@ -12,7 +12,7 @@ import (
)
func TestStore(t *testing.T) {
logger := logger.New(context.Background(), nil)
logger := logger.New(context.Background())
logger.Enable()
event := types.EventFields{

View File

@@ -2,20 +2,17 @@ package netflow
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
)
@@ -32,8 +29,8 @@ type Manager struct {
}
// NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
flowLogger := logger.New(ctx, statusRecorder)
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte) *Manager {
flowLogger := logger.New(ctx)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@@ -48,80 +45,46 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
}
}
// Update applies new flow configuration settings
// needsNewClient checks if a new client needs to be created
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
current := m.flowConfig
return previous == nil ||
!previous.Enabled ||
previous.TokenPayload != current.TokenPayload ||
previous.TokenSignature != current.TokenSignature ||
previous.URL != current.URL
}
// enableFlow starts components for flow tracking
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up
if m.needsNewClient(previous) {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %s", err)
}
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs(flowClient)
go m.startSender()
}
m.logger.Enable()
if m.conntrack != nil {
if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
return fmt.Errorf("start conntrack: %w", err)
}
}
return nil
}
// disableFlow stops components for flow tracking
func (m *Manager) disableFlow() error {
if m.conntrack != nil {
m.conntrack.Stop()
}
m.logger.Disable()
if m.receiverClient != nil {
return m.receiverClient.Close()
}
return nil
}
// Update applies new flow configuration settings
func (m *Manager) Update(update *nftypes.FlowConfig) error {
if update == nil {
return nil
}
m.mux.Lock()
defer m.mux.Unlock()
previous := m.flowConfig
m.flowConfig = update
if update.Enabled {
return m.enableFlow(previous)
if m.conntrack != nil {
if err := m.conntrack.Start(update.Counters); err != nil {
return fmt.Errorf("start conntrack: %w", err)
}
}
m.logger.Enable()
if previous == nil || !previous.Enabled {
flowClient, err := client.NewClient(m.ctx, m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature)
if err != nil {
return err
}
log.Infof("flow client connected to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs()
go m.startSender()
}
return nil
}
return m.disableFlow()
if m.conntrack != nil {
m.conntrack.Stop()
}
m.logger.Disable()
if previous != nil && previous.Enabled {
return m.receiverClient.Close()
}
return nil
}
// Close cleans up all resources
@@ -132,13 +95,6 @@ func (m *Manager) Close() {
if m.conntrack != nil {
m.conntrack.Close()
}
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %s", err)
}
}
m.logger.Close()
}
@@ -150,7 +106,6 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
func (m *Manager) startSender() {
ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
@@ -158,62 +113,56 @@ func (m *Manager) startSender() {
case <-ticker.C:
events := m.logger.GetEvents()
for _, event := range events {
if err := m.send(event); err != nil {
log.Errorf("failed to send flow event to server: %s", err)
continue
log.Infof("send flow event to server: %s", event.ID)
err := m.send(event)
if err != nil {
log.Errorf("send flow event to server: %s", err)
}
log.Tracef("sent flow event: %s", event.ID)
}
}
}
}
func (m *Manager) receiveACKs(client *client.GRPCClient) {
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
log.Tracef("received flow event ack: %s", ack.EventId)
m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
func (m *Manager) receiveACKs() {
if m.receiverClient == nil {
return
}
err := m.receiverClient.Receive(m.ctx, func(ack *proto.FlowEventAck) error {
log.Infof("receive flow event ack: %s", ack.EventId)
m.logger.DeleteEvents([]string{ack.EventId})
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("failed to receive flow event ack: %s", err)
if err != nil {
log.Errorf("receive flow event ack: %s", err)
}
}
func (m *Manager) send(event *nftypes.Event) error {
m.mux.Lock()
client := m.receiverClient
m.mux.Unlock()
if client == nil {
if m.receiverClient == nil {
return nil
}
return client.Send(toProtoEvent(m.publicKey, event))
return m.receiverClient.Send(m.ctx, toProtoEvent(m.publicKey, event))
}
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
protoEvent := &proto.FlowEvent{
EventId: event.ID[:],
EventId: event.ID,
Timestamp: timestamppb.New(event.Timestamp),
PublicKey: publicKey,
FlowFields: &proto.FlowFields{
FlowId: event.FlowID[:],
RuleId: event.RuleID,
Type: proto.Type(event.Type),
Direction: proto.Direction(event.Direction),
Protocol: uint32(event.Protocol),
SourceIp: event.SourceIP.AsSlice(),
DestIp: event.DestIP.AsSlice(),
RxPackets: event.RxPackets,
TxPackets: event.TxPackets,
RxBytes: event.RxBytes,
TxBytes: event.TxBytes,
SourceResourceId: event.SourceResourceID,
DestResourceId: event.DestResourceID,
FlowId: event.FlowID[:],
RuleId: event.RuleID,
Type: proto.Type(event.Type),
Direction: proto.Direction(event.Direction),
Protocol: uint32(event.Protocol),
SourceIp: event.SourceIP.AsSlice(),
DestIp: event.DestIP.AsSlice(),
RxPackets: event.RxPackets,
TxPackets: event.TxPackets,
RxBytes: event.RxBytes,
TxBytes: event.TxBytes,
},
}
if event.Protocol == nftypes.ICMP {
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
IcmpInfo: &proto.ICMPInfo{

View File

@@ -3,22 +3,18 @@ package store
import (
"sync"
"golang.org/x/exp/maps"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
)
func NewMemoryStore() *Memory {
return &Memory{
events: make(map[uuid.UUID]*types.Event),
events: make(map[string]*types.Event),
}
}
type Memory struct {
mux sync.Mutex
events map[uuid.UUID]*types.Event
events map[string]*types.Event
}
func (m *Memory) StoreEvent(event *types.Event) {
@@ -30,7 +26,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() {
m.mux.Lock()
defer m.mux.Unlock()
maps.Clear(m.events)
m.events = make(map[string]*types.Event)
}
func (m *Memory) GetEvents() []*types.Event {
@@ -43,7 +39,7 @@ func (m *Memory) GetEvents() []*types.Event {
return events
}
func (m *Memory) DeleteEvents(ids []uuid.UUID) {
func (m *Memory) DeleteEvents(ids []string) {
m.mux.Lock()
defer m.mux.Unlock()
for _, id := range ids {

View File

@@ -2,12 +2,11 @@ package types
import (
"net/netip"
"strconv"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
)
type Protocol uint8
@@ -28,10 +27,8 @@ func (p Protocol) String() string {
return "TCP"
case 17:
return "UDP"
case 132:
return "SCTP"
default:
return strconv.FormatUint(uint64(p), 10)
return "unknown"
}
}
@@ -64,29 +61,27 @@ const (
)
type Event struct {
ID uuid.UUID
ID string
Timestamp time.Time
EventFields
}
type EventFields struct {
FlowID uuid.UUID
Type Type
RuleID []byte
Direction Direction
Protocol Protocol
SourceIP netip.Addr
DestIP netip.Addr
SourceResourceID []byte
DestResourceID []byte
SourcePort uint16
DestPort uint16
ICMPType uint8
ICMPCode uint8
RxPackets uint64
TxPackets uint64
RxBytes uint64
TxBytes uint64
FlowID uuid.UUID
Type Type
RuleID []byte
Direction Direction
Protocol Protocol
SourceIP netip.Addr
DestIP netip.Addr
SourcePort uint16
DestPort uint16
ICMPType uint8
ICMPCode uint8
RxPackets uint64
TxPackets uint64
RxBytes uint64
TxBytes uint64
}
type FlowConfig struct {
@@ -113,7 +108,7 @@ type FlowLogger interface {
// GetEvents returns all stored events
GetEvents() []*Event
// DeleteEvents deletes events from the store
DeleteEvents([]uuid.UUID)
DeleteEvents([]string)
// Close closes the logger
Close()
// Enable enables the flow logger receiver
@@ -128,7 +123,7 @@ type Store interface {
// GetEvents returns all stored events
GetEvents() []*Event
// DeleteEvents deletes events from the store
DeleteEvents([]uuid.UUID)
DeleteEvents([]string)
// Close closes the store
Close()
}
@@ -147,5 +142,5 @@ type ConnTracker interface {
type IFaceMapper interface {
IsUserspaceBind() bool
Name() string
Address() wgaddr.Address
Address() device.WGAddress
}

View File

@@ -1,27 +1,12 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
const (
debounceTime = 2 * time.Second
)
var checkChangeFn = checkChange
var ErrStopped = errors.New("monitor has been stopped")
// NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct {
@@ -34,99 +19,3 @@ type NetworkMonitor struct {
func New() *NetworkMonitor {
return &NetworkMonitor{}
}
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()
return errors.New("network monitor already started")
}
ctx, nw.cancel = context.WithCancel(ctx)
defer nw.cancel()
nw.wg.Add(1)
nw.mu.Unlock()
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
// debounce changes
timer := time.NewTimer(0)
timer.Stop()
for {
select {
case <-event:
timer.Reset(debounceTime)
case <-timer.C:
return nil
case <-ctx.Done():
timer.Stop()
return ctx.Err()
}
}
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel == nil {
return
}
nw.cancel()
nw.wg.Wait()
}
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
close(event)
return
}
// prevent blocking
select {
case event <- struct{}{}:
default:
}
}
}

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err)
@@ -28,10 +28,18 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
}
}()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
for {
select {
case <-ctx.Done():
return ctx.Err()
return ErrStopped
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
@@ -68,11 +76,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
go callback()
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
go callback()
}
}
}

View File

@@ -0,0 +1,82 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.wg.Wait()
}
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available")
}
@@ -31,7 +31,8 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
for {
select {
case <-ctx.Done():
return ctx.Err()
return ErrStopped
// handle route changes
case route := <-routeChan:
// default route and main table
@@ -42,10 +43,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
// triggered on added/replaced routes
case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
}
}

View File

@@ -2,21 +2,10 @@
package networkmonitor
import (
"context"
"fmt"
)
import "context"
type NetworkMonitor struct {
}
// New creates a new network monitor.
func New() *NetworkMonitor {
return &NetworkMonitor{}
}
func (nw *NetworkMonitor) Listen(_ context.Context) error {
return fmt.Errorf("network monitor not supported on mobile platforms")
func (nw *NetworkMonitor) Start(context.Context, func()) error {
return nil
}
func (nw *NetworkMonitor) Stop() {

View File

@@ -1,99 +0,0 @@
package networkmonitor
import (
"context"
"errors"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
type MocMultiEvent struct {
counter int
}
func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if m.counter == 0 {
<-ctx.Done()
return ctx.Err()
}
time.Sleep(1 * time.Second)
m.counter--
return nil
}
func TestNetworkMonitor_Close(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
<-ctx.Done()
return ctx.Err()
}
nw := New()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
time.Sleep(1 * time.Second) // wait for the goroutine to start
nw.Stop()
<-done
if !errors.Is(resErr, context.Canceled) {
t.Errorf("unexpected error: %v", resErr)
}
}
func TestNetworkMonitor_Event(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-timeout.Done():
return nil
}
}
nw := New()
defer nw.Stop()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
<-done
if !errors.Is(resErr, nil) {
t.Errorf("unexpected error: %v", nil)
}
}
func TestNetworkMonitor_MultiEvent(t *testing.T) {
eventsRepeated := 3
me := &MocMultiEvent{counter: eventsRepeated}
checkChangeFn = me.checkChange
nw := New()
defer nw.Stop()
done := make(chan struct{})
started := time.Now()
go func() {
if resErr := nw.Listen(context.Background()); resErr != nil {
t.Errorf("unexpected error: %v", resErr)
}
close(done)
}()
<-done
expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime
if time.Since(started) < expectedResponseTime {
t.Errorf("unexpected duration: %v", time.Since(started))
}
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil {
return fmt.Errorf("failed to create route monitor: %w", err)
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
for {
select {
case <-ctx.Done():
return ctx.Err()
return ErrStopped
case route := <-routeMonitor.RouteUpdates():
if route.Destination.Bits() != 0 {
continue
}
if routeChanged(route, nexthopv4, nexthopv6) {
return nil
if routeChanged(route, nexthopv4, nexthopv6, callback) {
break
}
}
}
}
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
@@ -51,15 +51,18 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
case systemops.RouteModified:
// TODO: get routing table to figure out if our route is affected for modified routes
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
go callback()
return true
case systemops.RouteAdded:
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
go callback()
return true
}
case systemops.RouteDeleted:
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
go callback()
return true
}
}

View File

@@ -442,8 +442,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.isICEActive() {
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
@@ -711,8 +711,8 @@ func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
}
func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}
func (conn *Conn) removeWgPeer() error {

View File

@@ -8,7 +8,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -17,5 +16,4 @@ type WGIface interface {
RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error)
GetProxy() wgproxy.Proxy
Address() wgaddr.Address
}

View File

@@ -1,100 +0,0 @@
package peer
import (
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
type routeIDLookup struct {
localMap sync.Map
remoteMap sync.Map
resolvedIPs sync.Map
}
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
_, exists := r.localMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in local map", resourceID)
}
}
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
r.localMap.Delete(route)
}
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in remote map", resourceID)
}
}
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
r.remoteMap.Delete(route)
}
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
r.resolvedIPs.Store(route.Addr(), resourceID)
}
func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
r.resolvedIPs.Delete(route.Addr())
}
func (r *routeIDLookup) Lookup(src, dst netip.Addr, direction nftypes.Direction) (srcResourceID, dstResourceID string) {
// check resolved ip's first
resId, ok := r.resolvedIPs.Load(src)
if ok {
srcResourceID = resId.(string)
} else {
resId, ok := r.resolvedIPs.Load(dst)
if ok {
dstResourceID = resId.(string)
}
}
switch direction {
case nftypes.Ingress:
if srcResourceID == "" || dstResourceID == "" {
r.localMap.Range(func(key, value interface{}) bool {
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
srcResourceID = value.(string)
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
dstResourceID = value.(string)
}
if srcResourceID != "" && dstResourceID != "" {
return false
}
return true
})
}
case nftypes.Egress:
if srcResourceID == "" || dstResourceID == "" {
r.remoteMap.Range(func(key, value interface{}) bool {
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
srcResourceID = value.(string)
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
dstResourceID = value.(string)
}
if srcResourceID != "" && dstResourceID != "" {
return false
}
return true
})
}
}
return srcResourceID, dstResourceID
}

View File

@@ -17,7 +17,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/ingressgw"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
@@ -177,8 +176,6 @@ type Status struct {
eventQueue *EventQueue
ingressGwMgr *ingressgw.Manager
routeIDLookup routeIDLookup
}
// NewRecorder returns a new Status instance
@@ -314,7 +311,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
@@ -326,14 +323,6 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string)
peerState.AddRoute(route)
d.peers[peer] = peerState
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
} else {
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
}
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
@@ -351,28 +340,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
peerState.DeleteRoute(route)
d.peers[peer] = peerState
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
} else {
d.routeIDLookup.RemoveRemoteRouteID(pref)
}
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
// CheckRoutes checks if the source and destination addresses are within the same route
// and returns the resource ID of the route that contains the addresses
func (d *Status) CheckRoutes(src, dst netip.Addr, direction nftypes.Direction) (srcResId string, dstResId string) {
if d == nil {
return
}
return d.routeIDLookup.Lookup(src, dst, direction)
}
func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
@@ -586,50 +558,6 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.notifyAddressChanged()
}
// AddLocalPeerStateRoute adds a route to the local peer state
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
}
if d.localPeer.Routes == nil {
d.localPeer.Routes = map[string]struct{}{}
}
d.localPeer.Routes[route] = struct{}{}
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
}
// RemoveLocalPeerStateRoute removes a route from the local peer state
func (d *Status) RemoveLocalPeerStateRoute(route string) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
}
delete(d.localPeer.Routes, route)
d.routeIDLookup.RemoveLocalRouteID(pref)
}
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
func (d *Status) CleanLocalPeerStateRoutes() {
d.mux.Lock()
defer d.mux.Unlock()
d.localPeer.Routes = map[string]struct{}{}
}
// CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() {
d.mux.Lock()
@@ -713,7 +641,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
@@ -722,10 +650,6 @@ func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resol
Prefixes: prefixes,
ParentDomain: originalDomain,
}
for _, prefix := range prefixes {
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
}
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
@@ -736,10 +660,6 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
for k, v := range d.resolvedDomainsStates {
if v.ParentDomain == domain {
delete(d.resolvedDomainsStates, k)
for _, prefix := range v.Prefixes {
d.routeIDLookup.RemoveResolvedIP(prefix)
}
}
}
}

View File

@@ -358,12 +358,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
@@ -371,8 +365,14 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is handled by route exclusion / ip rules
// default route is
if prefix.Bits() == 0 {
continue
}

View File

@@ -330,7 +330,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
c.connectEvent()
}
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}

View File

@@ -160,12 +160,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
client := &dns.Client{
Timeout: 5 * time.Second,
Net: "udp",
@@ -321,7 +315,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",

View File

@@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes, r.route.GetResourceID())
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -3,9 +3,9 @@ package iface
import (
"net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgIfaceBase interface {
@@ -13,7 +13,7 @@ type wgIfaceBase interface {
RemoveAllowedIP(peerKey string, allowedIP string) error
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool
GetFilter() device.PacketFilter

View File

@@ -103,7 +103,9 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
delete(m.routes, route.ID)
m.statusRecorder.RemoveLocalPeerStateRoute(route.Network.String())
state := m.statusRecorder.GetLocalPeerState()
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
@@ -129,12 +131,18 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
@@ -156,7 +164,9 @@ func (m *serverRouter) cleanUp() {
}
m.statusRecorder.CleanLocalPeerStateRoutes()
state := m.statusRecorder.GetLocalPeerState()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
}
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {

View File

@@ -3,7 +3,7 @@ package server
import (
"fmt"
"os"
"path"
"path/filepath"
"syscall"
log "github.com/sirupsen/logrus"
@@ -12,6 +12,7 @@ import (
)
const (
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
stdErrorHandle = ^uintptr(11)
)
@@ -24,10 +25,13 @@ var (
)
func handlePanicLog() error {
// TODO: move this to a central location
logDir := path.Join(os.Getenv("PROGRAMDATA"), "Netbird")
logPath := path.Join(logDir, "netbird.err")
logPath := os.Getenv(windowsPanicLogEnvVar)
if logPath == "" {
return nil
}
// Ensure the directory exists
logDir := filepath.Dir(logPath)
if err := os.MkdirAll(logDir, 0750); err != nil {
return fmt.Errorf("create panic log directory: %w", err)
}
@@ -35,11 +39,13 @@ func handlePanicLog() error {
return fmt.Errorf("enforce permission on panic log file: %w", err)
}
// Open log file with append mode
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
return fmt.Errorf("open panic log file: %w", err)
}
// Redirect stderr to the file
if err = redirectStderr(f); err != nil {
if closeErr := f.Close(); closeErr != nil {
log.Warnf("failed to close file after redirect error: %v", closeErr)
@@ -53,6 +59,7 @@ func handlePanicLog() error {
// redirectStderr redirects stderr to the provided file
func redirectStderr(f *os.File) error {
// Get the current process's stderr handle
if err := setStdHandle(f); err != nil {
return fmt.Errorf("failed to set stderr handle: %w", err)
}

View File

@@ -160,7 +160,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
runningChan chan struct{},
runningChan chan error,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@@ -628,21 +628,20 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
runningChan := make(chan error)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
for {
select {
case <-runningChan:
return &proto.UpResponse{}, nil
case err := <-runningChan:
if err != nil {
log.Debugf("waiting for engine to become ready failed: %s", err)
} else {
return &proto.UpResponse{}, nil
}
case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready")
return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
}
}
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -42,21 +41,11 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
srcIP = engine.GetWgAddr()
}
srcAddr, ok := netip.AddrFromSlice(srcIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
dstIP := net.ParseIP(req.GetDestinationIp())
if req.GetDestinationIp() == "self" {
dstIP = engine.GetWgAddr()
}
dstAddr, ok := netip.AddrFromSlice(dstIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("invalid IP address")
}
@@ -96,8 +85,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
}
builder := &uspfilter.PacketBuilder{
SrcIP: srcAddr,
DstIP: dstAddr,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: protocol,
SrcPort: uint16(req.GetSourcePort()),
DstPort: uint16(req.GetDestinationPort()),

View File

@@ -4,10 +4,8 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
@@ -27,99 +25,95 @@ type GRPCClient struct {
realClient proto.FlowServiceClient
clientConn *grpc.ClientConn
stream proto.FlowService_EventsClient
streamMu sync.Mutex
}
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
var opts []grpc.DialOption
func NewClient(ctx context.Context, addr, payload, signature string) (*GRPCClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if strings.Contains(addr, "443") {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
}
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool,
})))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}))
}
opts = append(opts,
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithIdleTimeout(interval*2),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
withAuthToken(payload, signature),
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
)
conn, err := grpc.NewClient(addr, opts...)
if err != nil {
return nil, fmt.Errorf("creating new grpc client: %w", err)
return nil, fmt.Errorf("dialing with context: %s", err)
}
return &GRPCClient{
client := &GRPCClient{
realClient: proto.NewFlowServiceClient(conn),
clientConn: conn,
}, nil
}
return client, nil
}
func (c *GRPCClient) Close() error {
c.streamMu.Lock()
defer c.streamMu.Unlock()
c.stream = nil
return c.clientConn.Close()
}
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
backOff := defaultBackoff(ctx, interval)
func (c *GRPCClient) Receive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
backOff := defaultBackoff(ctx)
operation := func() error {
return c.establishStreamAndReceive(ctx, msgHandler)
connState := c.clientConn.GetState()
if connState == connectivity.Shutdown {
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
}
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
if err != nil {
return err
}
c.stream = stream
err = checkHeader(stream)
if err != nil {
return err
}
return c.receive(stream, msgHandler)
}
if err := backoff.Retry(operation, backOff); err != nil {
return fmt.Errorf("receive failed permanently: %w", err)
err := backoff.Retry(operation, backOff)
if err != nil {
log.Errorf("exiting the flow receiver service connection retry loop due to the unrecoverable error: %v", err)
return err
}
return nil
}
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
if c.clientConn.GetState() == connectivity.Shutdown {
return backoff.Permanent(errors.New("connection to flow receiver has been shut down"))
}
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
if err != nil {
return fmt.Errorf("create event stream: %w", err)
}
if err = checkHeader(stream); err != nil {
return fmt.Errorf("check header: %w", err)
}
c.streamMu.Lock()
c.stream = stream
c.streamMu.Unlock()
return c.receive(stream, msgHandler)
}
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
for {
msg, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive from stream: %w", err)
return err
}
if err := msgHandler(msg); err != nil {
return fmt.Errorf("handle message: %w", err)
return err
}
}
}
@@ -128,7 +122,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
header, err := stream.Header()
if err != nil {
log.Errorf("waiting for flow receiver header: %s", err)
return fmt.Errorf("wait for header: %w", err)
return err
}
if len(header) == 0 {
@@ -138,29 +132,26 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
return nil
}
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: interval / 2,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
c.streamMu.Lock()
stream := c.stream
c.streamMu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
func (c *GRPCClient) Send(ctx context.Context, event *proto.FlowEvent) error {
if c.stream == nil {
return fmt.Errorf("stream not initialized")
}
if err := stream.Send(event); err != nil {
return fmt.Errorf("send flow event: %w", err)
err := c.stream.Send(event)
if err != nil {
return fmt.Errorf("sending flow event: %s", err)
}
return nil

View File

@@ -130,7 +130,7 @@ type FlowEvent struct {
unknownFields protoimpl.UnknownFields
// Unique client event identifier
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
EventId string `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
// When the event occurred
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// Public key of the sending peer
@@ -170,11 +170,11 @@ func (*FlowEvent) Descriptor() ([]byte, []int) {
return file_flow_proto_rawDescGZIP(), []int{0}
}
func (x *FlowEvent) GetEventId() []byte {
func (x *FlowEvent) GetEventId() string {
if x != nil {
return x.EventId
}
return nil
return ""
}
func (x *FlowEvent) GetTimestamp() *timestamppb.Timestamp {
@@ -204,7 +204,7 @@ type FlowEventAck struct {
unknownFields protoimpl.UnknownFields
// Unique client event identifier that has been ack'ed
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
EventId string `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
}
func (x *FlowEventAck) Reset() {
@@ -239,11 +239,11 @@ func (*FlowEventAck) Descriptor() ([]byte, []int) {
return file_flow_proto_rawDescGZIP(), []int{1}
}
func (x *FlowEventAck) GetEventId() []byte {
func (x *FlowEventAck) GetEventId() string {
if x != nil {
return x.EventId
}
return nil
return ""
}
type FlowFields struct {
@@ -278,9 +278,6 @@ type FlowFields struct {
// Number of bytes
RxBytes uint64 `protobuf:"varint,12,opt,name=rx_bytes,json=rxBytes,proto3" json:"rx_bytes,omitempty"`
TxBytes uint64 `protobuf:"varint,13,opt,name=tx_bytes,json=txBytes,proto3" json:"tx_bytes,omitempty"`
// Resource ID
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
}
func (x *FlowFields) Reset() {
@@ -413,20 +410,6 @@ func (x *FlowFields) GetTxBytes() uint64 {
return 0
}
func (x *FlowFields) GetSourceResourceId() []byte {
if x != nil {
return x.SourceResourceId
}
return nil
}
func (x *FlowFields) GetDestResourceId() []byte {
if x != nil {
return x.DestResourceId
}
return nil
}
type isFlowFields_ConnectionInfo interface {
isFlowFields_ConnectionInfo()
}
@@ -565,7 +548,7 @@ var file_flow_proto_rawDesc = []byte{
0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0xb2, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
0x01, 0x28, 0x09, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75,
0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d,
@@ -576,8 +559,8 @@ var file_flow_proto_rawDesc = []byte{
0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x52, 0x0a, 0x66, 0x6c,
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x22, 0x29, 0x0a, 0x0c, 0x46, 0x6c, 0x6f, 0x77,
0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e,
0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e,
0x74, 0x49, 0x64, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c,
0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e,
0x74, 0x49, 0x64, 0x22, 0xc4, 0x03, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c,
0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x04, 0x74,
0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e, 0x66, 0x6c, 0x6f, 0x77,
@@ -604,36 +587,31 @@ var file_flow_proto_rawDesc = []byte{
0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20,
0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08,
0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07,
0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x6f, 0x75, 0x72, 0x63,
0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0e, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x65, 0x73, 0x6f, 0x75,
0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65,
0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x42,
0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e,
0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f,
0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12,
0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01,
0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08,
0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d,
0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f,
0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f,
0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59,
0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a,
0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08,
0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59,
0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72,
0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54,
0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a,
0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47,
0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12,
0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74,
0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65,
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f,
0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65,
0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75,
0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f,
0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74,
0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f,
0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a,
0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d,
0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79,
0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f,
0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41,
0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44,
0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10,
0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15,
0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e,
0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53,
0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42,
0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a,
0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46,
0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e,
0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01,
0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
}
var (

View File

@@ -13,7 +13,7 @@ service FlowService {
message FlowEvent {
// Unique client event identifier
bytes event_id = 1;
string event_id = 1;
// When the event occurred
google.protobuf.Timestamp timestamp = 2;
@@ -26,7 +26,7 @@ message FlowEvent {
message FlowEventAck {
// Unique client event identifier that has been ack'ed
bytes event_id = 1;
string event_id = 1;
}
message FlowFields {
@@ -67,11 +67,6 @@ message FlowFields {
// Number of bytes
uint64 rx_bytes = 12;
uint64 tx_bytes = 13;
// Resource ID
bytes source_resource_id = 14;
bytes dest_resource_id = 15;
}
// Flow event types

83
formatter/formatter.go Normal file
View File

@@ -0,0 +1,83 @@
package formatter
import (
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
)
// TextFormatter formats logs into text with included source code's path
type TextFormatter struct {
timestampFormat string
levelDesc []string
}
// SyslogFormatter formats logs into text
type SyslogFormatter struct {
levelDesc []string
}
var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
// NewTextFormatter create new MyTextFormatter instance
func NewTextFormatter() *TextFormatter {
return &TextFormatter{
levelDesc: validLevelDesc,
timestampFormat: time.RFC3339, // or RFC3339
}
}
// NewSyslogFormatter create new MySyslogFormatter instance
func NewSyslogFormatter() *SyslogFormatter {
return &SyslogFormatter{
levelDesc: validLevelDesc,
}
}
// Format renders a single log entry
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
keys := make([]string, 0, len(entry.Data))
for k, v := range entry.Data {
if k == "source" {
continue
}
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
}
if len(keys) > 0 {
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
}
level := f.parseLevel(entry.Level)
return []byte(fmt.Sprintf("%s %s %s%s: %s\n", entry.Time.Format(f.timestampFormat), level, fields, entry.Data["source"], entry.Message)), nil
}
func (f *TextFormatter) parseLevel(level logrus.Level) string {
if len(f.levelDesc) < int(level) {
return ""
}
return f.levelDesc[level]
}
// Format renders a single log entry
func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
keys := make([]string, 0, len(entry.Data))
for k, v := range entry.Data {
if k == "source" {
continue
}
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
}
if len(keys) > 0 {
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
}
return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
}

View File

@@ -1,4 +1,4 @@
package txt
package formatter
import (
"testing"
@@ -24,3 +24,20 @@ func TestLogTextFormat(t *testing.T) {
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}
func TestLogSyslogFormat(t *testing.T) {
someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
Level: 3,
Message: "Some Message",
}
formatter := NewSyslogFormatter()
result, _ := formatter.Format(someEntry)
parsedString := string(result)
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}

View File

@@ -1,4 +1,4 @@
package hook
package formatter
import (
"fmt"
@@ -41,8 +41,7 @@ func (hook ContextHook) Levels() []logrus.Level {
// Fire extend with the source information the entry.Data
func (hook ContextHook) Fire(entry *logrus.Entry) error {
src := hook.parseSrc(entry.Caller.File)
entry.Data[EntryKeySource] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
additionalEntries(entry)
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
if entry.Context == nil {
return nil

View File

@@ -1,9 +0,0 @@
//go:build !loggoroutine
package hook
import log "github.com/sirupsen/logrus"
func additionalEntries(_ *log.Entry) {
// This function is empty and is used to demonstrate the use of additional hooks.
}

Some files were not shown because too many files have changed in this diff Show More