mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 08:54:11 -04:00
Compare commits
6 Commits
test/netwo
...
feature/ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b02e9c3a8 | ||
|
|
eee90fbbbf | ||
|
|
85aea0a030 | ||
|
|
e41acdb9ac | ||
|
|
54dc27abec | ||
|
|
fdd7dc67c0 |
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,22 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
||||
|
||||
`netbird version`
|
||||
|
||||
**Is any other VPN software installed?**
|
||||
**NetBird status -dA output:**
|
||||
|
||||
If yes, which one?
|
||||
If applicable, add the `netbird status -dA' command output.
|
||||
|
||||
**Debug output**
|
||||
**Do you face any (non-mobile) client issues?**
|
||||
|
||||
To help us resolve the problem, please attach the following debug output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
As well as the file created by
|
||||
|
||||
netbird debug for 1m -AS
|
||||
|
||||
|
||||
We advise reviewing the anonymized output for any remaining personal information.
|
||||
Please provide the file created by `netbird debug for 1m -AS`.
|
||||
We advise reviewing the anonymized files for any remaining PII.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
@@ -55,10 +47,3 @@ If applicable, add screenshots to help explain your problem.
|
||||
**Additional context**
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -258,7 +258,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -325,8 +325,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -392,7 +392,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -461,7 +461,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
|
||||
@@ -134,11 +134,10 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{}, 1)
|
||||
clientErr := make(chan error, 1)
|
||||
run := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
clientErr <- err
|
||||
run <- err
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -148,9 +147,13 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-clientErr:
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
case <-run:
|
||||
case err := <-run:
|
||||
if err != nil {
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
||||
}
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
@@ -4,13 +4,12 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() device.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ type Manager struct {
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
|
||||
@@ -10,15 +10,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
|
||||
@@ -4,20 +4,21 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ const (
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
|
||||
@@ -16,15 +16,15 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
|
||||
@@ -3,20 +3,21 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -17,23 +17,26 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package uspfilter
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -22,13 +21,13 @@ const (
|
||||
firewallRuleName = "Netbird"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
// Close closes the firewall manager
|
||||
func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@@ -45,8 +44,8 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
|
||||
@@ -3,14 +3,14 @@ package common
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package conntrack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -13,15 +14,13 @@ import (
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
FlowId uuid.UUID
|
||||
Direction nftypes.Direction
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
lastSeen atomic.Int64
|
||||
PacketsTx atomic.Uint64
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
FlowId uuid.UUID
|
||||
Direction nftypes.Direction
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
@@ -31,17 +30,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// UpdateCounters safely updates the packet and byte counters
|
||||
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||
if direction == nftypes.Egress {
|
||||
b.PacketsTx.Add(1)
|
||||
b.BytesTx.Add(uint64(bytes))
|
||||
} else {
|
||||
b.PacketsRx.Add(1)
|
||||
b.BytesRx.Add(uint64(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
@@ -64,3 +52,16 @@ type ConnKey struct {
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
}
|
||||
|
||||
// makeConnKey creates a connection key
|
||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
|
||||
return ConnKey{
|
||||
SrcIP: srcAddr,
|
||||
DstIP: dstAddr,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
@@ -21,22 +21,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -46,22 +46,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -23,13 +23,14 @@ const (
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
type ICMPConnKey struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
ID uint16
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
Sequence uint16
|
||||
ID uint16
|
||||
}
|
||||
|
||||
func (i ICMPConnKey) String() string {
|
||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.ID, i.Sequence)
|
||||
}
|
||||
|
||||
// ICMPConnTrack represents an ICMP connection state
|
||||
@@ -45,8 +46,8 @@ type ICMPTracker struct {
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -56,27 +57,21 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||
key := ICMPConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
ID: id,
|
||||
}
|
||||
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -84,7 +79,6 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
||||
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
return key, true
|
||||
}
|
||||
@@ -93,21 +87,22 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP connection
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
|
||||
// TODO: icmp doesn't need to extend the timeout
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
@@ -117,7 +112,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
t.sendStartEvent(direction, key, typ, code)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -125,8 +120,8 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
},
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
@@ -138,20 +133,16 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := ICMPConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
ID: id,
|
||||
}
|
||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -162,19 +153,16 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.tickerCancel()
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-ctx.Done():
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -188,58 +176,56 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
ICMPCode: conn.ICMPCode,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||
fields := nftypes.EventFields{
|
||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeStart,
|
||||
RuleID: ruleID,
|
||||
Direction: direction,
|
||||
Protocol: nftypes.ICMP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
}
|
||||
if direction == nftypes.Ingress {
|
||||
fields.RxPackets = 1
|
||||
fields.RxBytes = uint64(size)
|
||||
} else {
|
||||
fields.TxPackets = 1
|
||||
fields.TxBytes = uint64(size)
|
||||
}
|
||||
t.flowLogger.StoreEvent(fields)
|
||||
})
|
||||
}
|
||||
|
||||
// makeICMPKey creates an ICMP connection key
|
||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
return ICMPConnKey{
|
||||
SrcIP: srcAddr,
|
||||
DstIP: dstAddr,
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -10,12 +10,12 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), 0)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -23,17 +23,17 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,8 +3,7 @@ package conntrack
|
||||
// TODO: Send RST packets for invalid/timed-out connections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -89,8 +88,6 @@ const (
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
tombstone atomic.Bool
|
||||
@@ -123,7 +120,7 @@ type TCPTracker struct {
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
done chan struct{}
|
||||
timeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
@@ -134,28 +131,21 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
timeout = DefaultTCPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
timeout: timeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -164,10 +154,9 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
if exists {
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||
conn.UpdateLastSeen()
|
||||
conn.Unlock()
|
||||
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
return key, true
|
||||
}
|
||||
|
||||
@@ -175,36 +164,37 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
conn.tombstone.Store(false)
|
||||
|
||||
@@ -215,17 +205,12 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -246,15 +231,15 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
return true
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, false)
|
||||
conn.UpdateLastSeen()
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
@@ -264,8 +249,6 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
conn.UpdateLastSeen()
|
||||
|
||||
state := conn.State
|
||||
defer func() {
|
||||
if state != conn.State {
|
||||
@@ -304,24 +287,17 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
||||
conn.State = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
} else if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
// Simultaneous close - both sides sent FIN
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPRst != 0:
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
@@ -329,7 +305,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
||||
conn.State = TCPStateTimeWait
|
||||
|
||||
t.logger.Trace("TCP connection %s completed", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
@@ -338,7 +314,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
@@ -352,7 +328,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
||||
conn.SetTombstone()
|
||||
|
||||
// Send close event for gracefully closed connections
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||
}
|
||||
}
|
||||
@@ -399,14 +375,12 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
func (t *TCPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-ctx.Done():
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -437,11 +411,11 @@ func (t *TCPTracker) cleanup() {
|
||||
// Return IPs to pool
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key)
|
||||
|
||||
// event already handled by state change
|
||||
if conn.State != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -449,7 +423,8 @@ func (t *TCPTracker) cleanup() {
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
// Clean up all remaining IPs
|
||||
t.mutex.Lock()
|
||||
@@ -471,20 +446,15 @@ func isValidFlagCombination(flags uint8) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.TCP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourcePort: key.SrcPort,
|
||||
DestPort: key.DstPort,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||
})
|
||||
}
|
||||
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// Send initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
|
||||
// Receive SYN-ACK
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
// Send ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
|
||||
// Test data transfer
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
require.True(t, valid, "Data should be allowed after handshake")
|
||||
},
|
||||
},
|
||||
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
require.True(t, valid, "ACK for FIN should be allowed")
|
||||
|
||||
// Receive FIN from other side
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
require.True(t, valid, "FIN should be allowed")
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Both sides send FIN+ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||
|
||||
// Both sides send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
require.True(t, valid, "Final ACKs should be allowed")
|
||||
},
|
||||
},
|
||||
@@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST in established",
|
||||
setupState: func() {
|
||||
// Establish connection first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
},
|
||||
wantValid: true,
|
||||
desc: "Should accept RST for established connection",
|
||||
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST without connection",
|
||||
setupState: func() {},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
},
|
||||
wantValid: false,
|
||||
desc: "Should reject RST without connection",
|
||||
@@ -208,12 +208,7 @@ func TestRSTHandling(t *testing.T) {
|
||||
tt.sendRST()
|
||||
|
||||
// Verify connection state is as expected
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
@@ -225,15 +220,15 @@ func TestRSTHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
@@ -241,12 +236,12 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -254,17 +249,17 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -272,16 +267,16 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||
}
|
||||
i++
|
||||
}
|
||||
@@ -296,10 +291,10 @@ func BenchmarkCleanup(b *testing.B) {
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -22,8 +21,6 @@ const (
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
type UDPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
}
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
@@ -32,8 +29,8 @@ type UDPTracker struct {
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -43,41 +40,34 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &UDPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -85,7 +75,6 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
@@ -93,21 +82,21 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
|
||||
@@ -116,17 +105,12 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -137,20 +121,17 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
func (t *UDPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-ctx.Done():
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -164,16 +145,16 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Trace("Removed UDP connection %s (timeout)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *UDPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections = nil
|
||||
@@ -181,16 +162,11 @@ func (t *UDPTracker) Close() {
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
@@ -200,20 +176,15 @@ func (t *UDPTracker) Timeout() time.Duration {
|
||||
return t.timeout
|
||||
}
|
||||
|
||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.UDP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourcePort: key.SrcPort,
|
||||
DestPort: key.DstPort,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -35,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
assert.NotNil(t, tracker.cleanupTicker)
|
||||
assert.NotNil(t, tracker.tickerCancel)
|
||||
assert.NotNil(t, tracker.done)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -49,15 +49,10 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
||||
|
||||
// Verify connection was tracked
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
||||
conn, exists := tracker.connections[key]
|
||||
require.True(t, exists)
|
||||
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||
@@ -71,18 +66,18 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
// Track outbound connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
sleep time.Duration
|
||||
@@ -99,7 +94,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid source IP",
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
@@ -109,7 +104,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
{
|
||||
name: "invalid destination IP",
|
||||
srcIP: dstIP,
|
||||
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.4"),
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
@@ -149,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
if tt.sleep > 0 {
|
||||
time.Sleep(tt.sleep)
|
||||
}
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
@@ -160,45 +155,42 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
timeout := 50 * time.Millisecond
|
||||
cleanupInterval := 25 * time.Millisecond
|
||||
|
||||
ctx, tickerCancel := context.WithCancel(context.Background())
|
||||
defer tickerCancel()
|
||||
|
||||
// Create tracker with custom cleanup interval
|
||||
tracker := &UDPTracker{
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
tickerCancel: tickerCancel,
|
||||
done: make(chan struct{}),
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
|
||||
// Add some connections
|
||||
connections := []struct {
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{
|
||||
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||
srcIP: net.ParseIP("192.168.1.2"),
|
||||
dstIP: net.ParseIP("192.168.1.3"),
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
},
|
||||
{
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.5"),
|
||||
srcPort: 12346,
|
||||
dstPort: 53,
|
||||
},
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
@@ -223,12 +215,12 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -236,17 +228,17 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,16 +15,13 @@ import (
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
flowID := uuid.New()
|
||||
|
||||
// Extract ICMP header to get type and code
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
@@ -36,6 +33,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
@@ -43,15 +42,30 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||
}()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id, flowID)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
default:
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
@@ -59,20 +73,21 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID, flowID uuid.UUID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
@@ -81,7 +96,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
@@ -102,11 +117,13 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
@@ -117,11 +134,9 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
// TODO: get packets/bytes
|
||||
})
|
||||
}
|
||||
|
||||
@@ -22,14 +22,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
}
|
||||
}()
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
@@ -58,7 +51,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
@@ -74,7 +66,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
}
|
||||
ep.Close()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id)
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
@@ -106,27 +98,17 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
fields := nftypes.EventFields{
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
|
||||
|
||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
Protocol: 6,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
SourcePort: id.LocalPort,
|
||||
DestPort: id.RemotePort,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -89,6 +89,21 @@ func (f *udpForwarder) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
|
||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: 17,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
SourcePort: id.LocalPort,
|
||||
DestPort: id.RemotePort,
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup periodically removes idle UDP connections
|
||||
func (f *udpForwarder) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
@@ -125,6 +140,8 @@ func (f *udpForwarder) cleanup() {
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -148,19 +165,13 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
}
|
||||
}()
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
@@ -173,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -200,14 +212,13 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
}
|
||||
|
||||
@@ -227,7 +238,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id)
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
@@ -253,30 +264,19 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
}
|
||||
}
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
fields := nftypes.EventFields{
|
||||
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
|
||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
Protocol: 17, // UDP protocol number
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||
}
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
SourcePort: id.LocalPort,
|
||||
DestPort: id.RemotePort,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
|
||||
@@ -3,7 +3,6 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -32,9 +31,13 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
@@ -119,12 +122,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ip.Is4() {
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -2,91 +2,90 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr wgaddr.Address
|
||||
testIP netip.Addr
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: wgaddr.Address{
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: wgaddr.Address{
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: wgaddr.Address{
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: wgaddr.Address{
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
@@ -96,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return tt.setupAddr
|
||||
},
|
||||
}
|
||||
@@ -175,7 +174,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
type PeerRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
ip netip.Addr
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
|
||||
@@ -2,7 +2,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP netip.Addr
|
||||
DestinationIP netip.Addr
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,30 +260,28 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if !m.handleRouting(trace) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter.Load() {
|
||||
if m.nativeRouter {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
msg = m.buildConntrackStateMessage(d)
|
||||
@@ -311,46 +309,39 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
|
||||
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
strRuleId := "<no id>"
|
||||
strRuleId := "implicit"
|
||||
if ruleId != nil {
|
||||
strRuleId = string(ruleId)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||
if blocked {
|
||||
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||
trace.AddResult(StagePeerACL, msg, false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StagePeerACL, msg, true)
|
||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||
|
||||
// Handle netstack mode
|
||||
if m.netstack {
|
||||
switch {
|
||||
case !m.localForwarding:
|
||||
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||
case m.forwarder.Load() != nil:
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
default:
|
||||
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||
}
|
||||
return true
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||
}
|
||||
|
||||
// In normal mode, packets are allowed through for local delivery
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled.Load() {
|
||||
if !m.routingEnabled {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
@@ -366,14 +357,14 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
strId = "<no id>"
|
||||
strId = "implicit"
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||
@@ -382,7 +373,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder.Load() != nil {
|
||||
if allowed && m.forwarder != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
@@ -401,7 +392,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData, 0)
|
||||
dropped := m.processOutgoingHooks(packetData)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
|
||||
@@ -1,440 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||
t.Logf("Trace results: %v", trace.Results)
|
||||
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||
for _, result := range trace.Results {
|
||||
actualStages = append(actualStages, result.Stage)
|
||||
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||
}
|
||||
|
||||
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||
lastResult := trace.Results[len(trace.Results)-1]
|
||||
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||
}
|
||||
|
||||
func TestTracePacket(t *testing.T) {
|
||||
setupTracerTest := func(statefulMode bool) *Manager {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
m.stateful = false
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||
builder := &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: protocol,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Direction: direction,
|
||||
}
|
||||
|
||||
if protocol == "tcp" {
|
||||
builder.TCPState = &TCPState{SYN: true}
|
||||
}
|
||||
|
||||
return builder
|
||||
}
|
||||
|
||||
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||
return &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: "icmp",
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
Direction: direction,
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(*Manager)
|
||||
packetBuilder func() *PacketBuilder
|
||||
expectedStages []PacketStage
|
||||
expectedAllow bool
|
||||
}{
|
||||
{
|
||||
name: "LocalTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = true
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithoutForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_NativeRouter",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_RoutingDisabled",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(false)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "ConnectionTracking_Hit",
|
||||
setup: func(m *Manager) {
|
||||
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||
return pb
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "OutboundTraffic",
|
||||
setup: func(m *Manager) {
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPEchoRequest",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPDestinationUnreachable",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithoutHook",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "StatefulDisabled_NoTracking",
|
||||
setup: func(m *Manager) {
|
||||
m.stateful = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
m := setupTracerTest(true)
|
||||
|
||||
tc.setup(m)
|
||||
|
||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||
"100.10.0.100 should be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||
"172.17.0.2 should not be recognized as a local IP")
|
||||
|
||||
pb := tc.packetBuilder()
|
||||
|
||||
trace, err := m.TracePacketFromBuilder(pb)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifyTraceStages(t, trace, tc.expectedStages)
|
||||
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -67,9 +66,9 @@ func (r RouteRules) Sort() {
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[netip.Addr]RuleSet
|
||||
outgoingRules map[string]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
@@ -81,9 +80,9 @@ type Manager struct {
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled atomic.Bool
|
||||
routingEnabled bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
nativeRouter atomic.Bool
|
||||
nativeRouter bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
@@ -96,7 +95,7 @@ type Manager struct {
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
forwarder *forwarder.Forwarder
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
@@ -169,18 +168,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||
incomingRules: make(map[netip.Addr]RuleSet),
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
@@ -212,7 +211,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder.Load() == nil {
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
@@ -256,20 +255,20 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case m.disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("server routes are disabled")
|
||||
|
||||
case forceUserspaceRouter:
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
@@ -277,19 +276,19 @@ func (m *Manager) determineRouting() error {
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("native routing is enabled")
|
||||
|
||||
default:
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||
if m.routingEnabled && !m.nativeRouter {
|
||||
return m.initForwarder()
|
||||
}
|
||||
|
||||
@@ -298,24 +297,24 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder.Load() != nil {
|
||||
if m.forwarder != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := m.wgIface.GetWGDevice()
|
||||
if intf == nil {
|
||||
m.routingEnabled.Store(false)
|
||||
m.routingEnabled = false
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||
if err != nil {
|
||||
m.routingEnabled.Store(false)
|
||||
m.routingEnabled = false
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
}
|
||||
|
||||
m.forwarder.Store(forwarder)
|
||||
m.forwarder = forwarder
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
@@ -331,7 +330,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
|
||||
@@ -342,7 +341,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
return nil
|
||||
@@ -361,23 +360,17 @@ func (m *Manager) AddPeerFiltering(
|
||||
action firewall.Action,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
// TODO: fix in upper layers
|
||||
i, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
|
||||
i = i.Unmap()
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
mgmtId: id,
|
||||
ip: i,
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
}
|
||||
if i.Is4() {
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
r.ip = ipNormalized
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
@@ -402,10 +395,10 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(RuleSet)
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
@@ -419,10 +412,13 @@ func (m *Manager) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
@@ -436,16 +432,14 @@ func (m *Manager) AddRouteFiltering(
|
||||
action: action,
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules.Sort()
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@@ -474,10 +468,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip], r.id)
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -510,13 +504,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||
return m.processOutgoingHooks(packetData, size)
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
return m.processOutgoingHooks(packetData)
|
||||
}
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||
return m.dropFilter(packetData, size)
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
@@ -524,7 +518,10 @@ func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -537,34 +534,31 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
if srcIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
// Track all protocols if stateful mode is enabled
|
||||
if m.stateful {
|
||||
m.trackOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
if m.stateful {
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
// Process UDP hooks even if stateful mode is disabled
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
return src, dst
|
||||
return d.ip4.SrcIP, d.ip4.DstIP
|
||||
case layers.LayerTypeIPv6:
|
||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||
return src, dst
|
||||
return d.ip6.SrcIP, d.ip6.DstIP
|
||||
default:
|
||||
return netip.Addr{}, netip.Addr{}
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -591,70 +585,51 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
|
||||
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode)
|
||||
}
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -663,19 +638,19 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if !srcIP.IsValid() {
|
||||
if srcIP == nil {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
@@ -683,28 +658,27 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
if blocked {
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
RuleID: ruleId,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourceIP: srcAddr,
|
||||
DestIP: dstAddr,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
RxPackets: 1,
|
||||
RxBytes: uint64(size),
|
||||
})
|
||||
return true
|
||||
}
|
||||
@@ -715,7 +689,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
}
|
||||
|
||||
// track inbound packets to get the correct direction and session id for flows
|
||||
m.trackInbound(d, srcIP, dstIP, ruleID, size)
|
||||
m.trackInbound(d, srcIP, dstIP)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -726,12 +700,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.forwarder.Load() == nil {
|
||||
if m.forwarder == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
return true
|
||||
}
|
||||
|
||||
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
|
||||
@@ -741,34 +715,37 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled.Load() {
|
||||
if !m.routingEnabled {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter.Load() {
|
||||
if m.nativeRouter {
|
||||
return false
|
||||
}
|
||||
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
id, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
RuleID: id,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourceIP: srcAddr,
|
||||
DestIP: dstAddr,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
@@ -777,7 +754,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
@@ -822,7 +799,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return m.tcpTracker.IsValidInbound(
|
||||
@@ -831,7 +808,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
getTCPFlags(&d.tcp),
|
||||
size,
|
||||
)
|
||||
|
||||
case layers.LayerTypeUDP:
|
||||
@@ -840,7 +816,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
size,
|
||||
)
|
||||
|
||||
case layers.LayerTypeICMPv4:
|
||||
@@ -848,8 +823,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
d.icmp4.TypeCode.Type(),
|
||||
size,
|
||||
)
|
||||
|
||||
// TODO: ICMPv6
|
||||
@@ -869,22 +844,20 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) {
|
||||
if m.isSpecialICMP(d) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
@@ -909,10 +882,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -946,13 +919,16 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
@@ -996,7 +972,9 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
@@ -1006,22 +984,23 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu
|
||||
udpHook: hook,
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
if ip.To4() != nil {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
@@ -1069,21 +1048,20 @@ func (m *Manager) DisableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
fwder := m.forwarder.Load()
|
||||
if fwder == nil {
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
fwder.Stop()
|
||||
m.forwarder.Store(nil)
|
||||
m.forwarder.Stop()
|
||||
m.forwarder = nil
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
|
||||
@@ -193,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -230,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
@@ -238,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.processOutgoingHooks(testOut, 0)
|
||||
manager.processOutgoingHooks(testOut)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn, 0)
|
||||
manager.dropFilter(testIn)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -278,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -477,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.processOutgoingHooks(syn)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.dropFilter(synack)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -624,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.processOutgoingHooks(syn)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
@@ -655,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -761,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -826,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.processOutgoingHooks(syn)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
@@ -856,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -950,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
@@ -26,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||
})
|
||||
}
|
||||
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
@@ -306,8 +306,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NoError(tb, err)
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled.Load())
|
||||
require.False(tb, manager.nativeRouter.Load())
|
||||
require.True(tb, manager.routingEnabled)
|
||||
require.False(tb, manager.nativeRouter)
|
||||
|
||||
tb.Cleanup(func() {
|
||||
require.NoError(tb, manager.Close(nil))
|
||||
@@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
@@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
})
|
||||
|
||||
for i, p := range tc.packets {
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
srcIP := net.ParseIP(p.srcIP)
|
||||
dstIP := net.ParseIP(p.dstIP)
|
||||
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
|
||||
@@ -18,17 +18,17 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() wgaddr.Address
|
||||
AddressFunc func() iface.WGAddress
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
@@ -54,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
return i.SetFilterFunc(iface)
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address() wgaddr.Address {
|
||||
func (i *IFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc == nil {
|
||||
return wgaddr.Address{}
|
||||
return iface.WGAddress{}
|
||||
}
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ip := netip.MustParseAddr("192.168.1.1")
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
ip net.IP
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
@@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
ip: net.IPv4(10, 168, 0, 1),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
@@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
ip: net.IPv6loopback,
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
@@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
@@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
@@ -269,8 +269,8 @@ func TestManagerReset(t *testing.T) {
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), 0) {
|
||||
if m.dropFilter(buf.Bytes()) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
net.ParseIP("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
hookCalled = true
|
||||
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
result := manager.processOutgoingHooks(buf.Bytes())
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
result = manager.processOutgoingHooks(buf.Bytes())
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
@@ -569,11 +569,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
||||
|
||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||
@@ -636,12 +636,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
if cp.shouldAllow {
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
||||
require.True(t, exists, "Connection should still exist during valid window")
|
||||
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||
"LastSeen should be updated for valid responses")
|
||||
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(testBuf.Bytes())
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
@@ -13,8 +14,6 @@ import (
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -53,10 +52,9 @@ type ICEBind struct {
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
@@ -66,7 +64,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
@@ -111,17 +108,35 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// force IPv4
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeIP] = conn
|
||||
b.endpoints[fakeAddr] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
|
||||
return fakeUDPAddr, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
log.Warnf("failed to convert IP to netip.Addr")
|
||||
return
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
defer b.endpointsMu.Unlock()
|
||||
|
||||
delete(b.endpoints, fakeIP)
|
||||
delete(b.endpoints, fakeAddr)
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
@@ -146,10 +161,9 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
@@ -261,6 +275,21 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
newAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
||||
Port: peerAddress.Port,
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
@@ -17,8 +17,6 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
@@ -43,7 +41,6 @@ type UniversalUDPMuxParams struct {
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
FilterFn FilterFn
|
||||
WGAddress wgaddr.Address
|
||||
}
|
||||
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||
@@ -67,7 +64,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
filterFn: params.FilterFn,
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
// embed UDPMux
|
||||
@@ -122,7 +118,6 @@ type udpConn struct {
|
||||
filterFn FilterFn
|
||||
// TODO: reset cache on route changes
|
||||
addrCache sync.Map
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
@@ -164,11 +159,6 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a.AsSlice()) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
|
||||
@@ -9,14 +9,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
|
||||
@@ -1,29 +1,29 @@
|
||||
package wgaddr
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
// WGAddress WireGuard parsed address
|
||||
type WGAddress struct {
|
||||
IP net.IP
|
||||
Network *net.IPNet
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func ParseWGAddress(address string) (Address, error) {
|
||||
func ParseWGAddress(address string) (WGAddress, error) {
|
||||
ip, network, err := net.ParseCIDR(address)
|
||||
if err != nil {
|
||||
return Address{}, err
|
||||
return WGAddress{}, err
|
||||
}
|
||||
return Address{
|
||||
return WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (addr Address) String() string {
|
||||
func (addr WGAddress) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
}
|
||||
@@ -13,12 +13,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||
type WGTunDevice struct {
|
||||
address wgaddr.Address
|
||||
address WGAddress
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -32,7 +31,7 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -94,7 +93,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
|
||||
// todo implement
|
||||
return nil
|
||||
}
|
||||
@@ -124,7 +123,7 @@ func (t *WGTunDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) WgAddress() wgaddr.Address {
|
||||
func (t *WGTunDevice) WgAddress() WGAddress {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
address WGAddress
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -30,7 +29,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -86,7 +85,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -107,7 +106,7 @@ func (t *TunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@@ -11,16 +10,16 @@ import (
|
||||
// PacketFilter interface for firewall abilities
|
||||
type PacketFilter interface {
|
||||
// DropOutgoing filter outgoing packets from host to external destinations
|
||||
DropOutgoing(packetData []byte, size int) bool
|
||||
DropOutgoing(packetData []byte) bool
|
||||
|
||||
// DropIncoming filter incoming packets from external sources to host
|
||||
DropIncoming(packetData []byte, size int) bool
|
||||
DropIncoming(packetData []byte) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
@@ -58,7 +57,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
@@ -82,7 +81,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||
if !filter.DropIncoming(buf[offset:]) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
dropped++
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
return 1, nil
|
||||
})
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
|
||||
@@ -14,12 +14,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
address WGAddress
|
||||
port int
|
||||
key string
|
||||
iceBind *bind.ICEBind
|
||||
@@ -31,7 +30,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -121,11 +120,11 @@ func (t *TunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
return t.address
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
||||
// todo implement
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,13 +14,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
address WGAddress
|
||||
wgPort int
|
||||
key string
|
||||
mtu int
|
||||
@@ -35,7 +34,7 @@ type TunKernelDevice struct {
|
||||
filterFn bind.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &TunKernelDevice{
|
||||
ctx: ctx,
|
||||
@@ -100,10 +99,9 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
bindParams := bind.UniversalUDPMuxParams{
|
||||
UDPConn: rawSock,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
UDPConn: rawSock,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
}
|
||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
||||
go mux.ReadFromConn(t.ctx)
|
||||
@@ -114,7 +112,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return t.udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -147,7 +145,7 @@ func (t *TunKernelDevice) Close() error {
|
||||
return closErr
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
|
||||
func (t *TunKernelDevice) WgAddress() WGAddress {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,13 +13,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
address WGAddress
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -35,7 +34,7 @@ type TunNetstackDevice struct {
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
return &TunNetstackDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -98,7 +97,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
|
||||
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -117,7 +116,7 @@ func (t *TunNetstackDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
|
||||
func (t *TunNetstackDevice) WgAddress() WGAddress {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -12,12 +12,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type USPDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
address WGAddress
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -29,7 +28,7 @@ type USPDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
return &USPDevice{
|
||||
@@ -94,7 +93,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
func (t *USPDevice) UpdateAddr(address WGAddress) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -114,7 +113,7 @@ func (t *USPDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||
func (t *USPDevice) WgAddress() WGAddress {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,14 +13,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
address WGAddress
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -33,7 +32,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -119,7 +118,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -140,7 +139,7 @@ func (t *TunDevice) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/freebsd"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
@@ -57,7 +56,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
link, err := freebsd.LinkByName(l.name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link by name: %w", err)
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
@@ -92,7 +90,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(l, 0)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,14 +7,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
@@ -29,6 +28,8 @@ const (
|
||||
WgInterfaceDefault = configurer.WgInterfaceDefault
|
||||
)
|
||||
|
||||
type WGAddress = device.WGAddress
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
Free() error
|
||||
@@ -71,7 +72,7 @@ func (w *WGIface) Name() string {
|
||||
}
|
||||
|
||||
// Address returns the interface address
|
||||
func (w *WGIface) Address() wgaddr.Address {
|
||||
func (w *WGIface) Address() device.WGAddress {
|
||||
return w.tun.WgAddress()
|
||||
}
|
||||
|
||||
@@ -102,7 +103,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||
addr, err := device.ParseWGAddress(newAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,18 +3,17 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
|
||||
@@ -6,18 +6,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -5,18 +5,17 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
||||
|
||||
@@ -8,13 +8,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -22,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
@@ -35,7 +34,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
return wgIFace, nil
|
||||
}
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
|
||||
@@ -4,17 +4,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -6,7 +6,6 @@ package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -36,7 +35,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
@@ -50,31 +49,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
||||
}
|
||||
|
||||
// DropIncoming mocks base method.
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DropIncoming indicates an expected call of DropIncoming.
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||
}
|
||||
|
||||
// DropOutgoing mocks base method.
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -17,13 +16,13 @@ import (
|
||||
type ProxyBind struct {
|
||||
Bind *bind.ICEBind
|
||||
|
||||
fakeNetIP *netip.AddrPort
|
||||
wgBindEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
wgAddr *net.UDPAddr
|
||||
wgEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
@@ -34,24 +33,20 @@ type ProxyBind struct {
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
fakeNetIP, err := fakeAddress(nbAddr)
|
||||
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.fakeNetIP = fakeNetIP
|
||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.wgAddr = addr
|
||||
p.wgEndpoint = addrToEndpoint(addr)
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return nil
|
||||
return err
|
||||
|
||||
}
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
||||
Port: int(p.fakeNetIP.Port()),
|
||||
Zone: p.fakeNetIP.Addr().Zone(),
|
||||
}
|
||||
return p.wgAddr
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Work() {
|
||||
@@ -59,8 +54,6 @@ func (p *ProxyBind) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
@@ -100,7 +93,7 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
@@ -133,7 +126,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgBindEndpoint,
|
||||
Endpoint: p.wgEndpoint,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
@@ -141,19 +134,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse new IP: %w", err)
|
||||
}
|
||||
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@
|
||||
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
|
||||
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
|
||||
|
||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||
|
||||
Unicode True
|
||||
|
||||
######################################################################
|
||||
@@ -70,9 +68,6 @@ ShowInstDetails Show
|
||||
|
||||
!insertmacro MUI_PAGE_DIRECTORY
|
||||
|
||||
; Custom page for autostart checkbox
|
||||
Page custom AutostartPage AutostartPageLeave
|
||||
|
||||
!insertmacro MUI_PAGE_INSTFILES
|
||||
|
||||
!insertmacro MUI_PAGE_FINISH
|
||||
@@ -85,36 +80,8 @@ Page custom AutostartPage AutostartPageLeave
|
||||
|
||||
!insertmacro MUI_LANGUAGE "English"
|
||||
|
||||
; Variables for autostart option
|
||||
Var AutostartCheckbox
|
||||
Var AutostartEnabled
|
||||
|
||||
######################################################################
|
||||
|
||||
; Function to create the autostart options page
|
||||
Function AutostartPage
|
||||
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
|
||||
|
||||
nsDialogs::Create 1018
|
||||
Pop $0
|
||||
|
||||
${If} $0 == error
|
||||
Abort
|
||||
${EndIf}
|
||||
|
||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||
Pop $AutostartCheckbox
|
||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||
|
||||
nsDialogs::Show
|
||||
FunctionEnd
|
||||
|
||||
; Function to handle leaving the autostart page
|
||||
Function AutostartPageLeave
|
||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||
FunctionEnd
|
||||
|
||||
Function GetAppFromCommand
|
||||
Exch $1
|
||||
Push $2
|
||||
@@ -196,16 +163,6 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
EnVar::SetHKLM
|
||||
EnVar::AddValueEx "path" "$INSTDIR"
|
||||
|
||||
@@ -229,10 +186,7 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
|
||||
# kill ui client
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
||||
|
||||
# wait the service uninstall take unblock the executable
|
||||
Sleep 3000
|
||||
|
||||
@@ -9,13 +9,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
@@ -49,7 +49,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
@@ -343,7 +343,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
iface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
||||
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
|
||||
}
|
||||
|
||||
// Address mocks base method.
|
||||
func (m *MockIFaceMapper) Address() wgaddr.Address {
|
||||
func (m *MockIFaceMapper) Address() iface.WGAddress {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Address")
|
||||
ret0, _ := ret[0].(wgaddr.Address)
|
||||
ret0, _ := ret[0].(iface.WGAddress)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||
func (c *ConnectClient) Run(runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -159,9 +159,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
runningChanOpen := true
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
if c.ctx.Err() != nil {
|
||||
if c.isContextCancelled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -281,11 +282,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case runningChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
if runningChan != nil && runningChanOpen {
|
||||
runningChan <- nil
|
||||
close(runningChan)
|
||||
runningChanOpen = false
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
@@ -379,6 +379,15 @@ func (c *ConnectClient) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConnectClient) isContextCancelled() bool {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SetNetworkMapPersistence enables or disables network map persistence.
|
||||
// When enabled, the last received network map will be stored and can be retrieved
|
||||
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -31,7 +30,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
@@ -41,9 +40,9 @@ func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
func (w *mocWGIface) Address() iface.WGAddress {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return wgaddr.Address{
|
||||
return iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}
|
||||
@@ -459,7 +458,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
|
||||
@@ -2,7 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -117,10 +117,5 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
||||
@@ -5,15 +5,15 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
|
||||
@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:])
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
@@ -1641,19 +1641,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("restarting engine")
|
||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
log.Infof("cancelling client context, engine will be recreated")
|
||||
log.Infof("cancelling client, engine will be recreated")
|
||||
e.clientCancel()
|
||||
}
|
||||
|
||||
@@ -1665,17 +1662,34 @@ func (e *Engine) startNetworkMonitor() {
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
go func() {
|
||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Infof("network monitor stopped")
|
||||
return
|
||||
}
|
||||
log.Errorf("network monitor error: %v", err)
|
||||
return
|
||||
}
|
||||
var mu sync.Mutex
|
||||
var debounceTimer *time.Timer
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
||||
// a network change is detected, or an error occurs on start up
|
||||
err := e.networkMonitor.Start(e.ctx, func() {
|
||||
// This function is called when a network change is detected
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
log.Infof("Network monitor: detected network change, reset debounceTimer")
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(2*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
||||
log.Errorf("Network monitor: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -76,7 +75,7 @@ type MockWGIface struct {
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
AddressFunc func() device.WGAddress
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
@@ -115,7 +114,7 @@ func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() wgaddr.Address {
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
@@ -365,8 +364,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
@@ -21,7 +20,7 @@ type wgIfaceBase interface {
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() device.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/store"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type rcvChan chan *types.EventFields
|
||||
@@ -22,17 +21,15 @@ type Logger struct {
|
||||
enabled atomic.Bool
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancelReceiver context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
|
||||
func New(ctx context.Context) *Logger {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Logger{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
Store: store.NewMemoryStore(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
Store: store.NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,14 +58,13 @@ func (l *Logger) startReceiver() {
|
||||
if l.enabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
l.mux.Lock()
|
||||
ctx, cancel := context.WithCancel(l.ctx)
|
||||
l.cancelReceiver = cancel
|
||||
l.mux.Unlock()
|
||||
|
||||
c := make(rcvChan, 100)
|
||||
l.rcvChan.Store(&c)
|
||||
l.rcvChan.Swap(&c)
|
||||
l.enabled.Store(true)
|
||||
|
||||
for {
|
||||
@@ -77,15 +73,12 @@ func (l *Logger) startReceiver() {
|
||||
log.Info("flow Memory store receiver stopped")
|
||||
return
|
||||
case eventFields := <-c:
|
||||
id := uuid.New()
|
||||
id := uuid.NewString()
|
||||
event := types.Event{
|
||||
ID: id,
|
||||
EventFields: *eventFields,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
|
||||
event.SourceResourceID = []byte(srcResId)
|
||||
event.DestResourceID = []byte(dstResId)
|
||||
l.Store.StoreEvent(&event)
|
||||
}
|
||||
}
|
||||
@@ -107,7 +100,6 @@ func (l *Logger) stop() {
|
||||
l.cancelReceiver()
|
||||
l.cancelReceiver = nil
|
||||
}
|
||||
l.rcvChan.Store(nil)
|
||||
l.mux.Unlock()
|
||||
}
|
||||
|
||||
@@ -115,7 +107,7 @@ func (l *Logger) GetEvents() []*types.Event {
|
||||
return l.Store.GetEvents()
|
||||
}
|
||||
|
||||
func (l *Logger) DeleteEvents(ids []uuid.UUID) {
|
||||
func (l *Logger) DeleteEvents(ids []string) {
|
||||
l.Store.DeleteEvents(ids)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(context.Background(), nil)
|
||||
logger := logger.New(context.Background())
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
|
||||
@@ -2,20 +2,17 @@ package netflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/logger"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/flow/client"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
)
|
||||
@@ -32,8 +29,8 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
flowLogger := logger.New(ctx, statusRecorder)
|
||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte) *Manager {
|
||||
flowLogger := logger.New(ctx)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
@@ -48,80 +45,46 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
|
||||
}
|
||||
}
|
||||
|
||||
// Update applies new flow configuration settings
|
||||
// needsNewClient checks if a new client needs to be created
|
||||
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||
current := m.flowConfig
|
||||
return previous == nil ||
|
||||
!previous.Enabled ||
|
||||
previous.TokenPayload != current.TokenPayload ||
|
||||
previous.TokenSignature != current.TokenSignature ||
|
||||
previous.URL != current.URL
|
||||
}
|
||||
|
||||
// enableFlow starts components for flow tracking
|
||||
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
// first make sender ready so events don't pile up
|
||||
if m.needsNewClient(previous) {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
go m.receiveACKs(flowClient)
|
||||
go m.startSender()
|
||||
}
|
||||
|
||||
m.logger.Enable()
|
||||
|
||||
if m.conntrack != nil {
|
||||
if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
|
||||
return fmt.Errorf("start conntrack: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableFlow stops components for flow tracking
|
||||
func (m *Manager) disableFlow() error {
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Stop()
|
||||
}
|
||||
|
||||
m.logger.Disable()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update applies new flow configuration settings
|
||||
func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
if update == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
previous := m.flowConfig
|
||||
m.flowConfig = update
|
||||
|
||||
if update.Enabled {
|
||||
return m.enableFlow(previous)
|
||||
if m.conntrack != nil {
|
||||
if err := m.conntrack.Start(update.Counters); err != nil {
|
||||
return fmt.Errorf("start conntrack: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Enable()
|
||||
if previous == nil || !previous.Enabled {
|
||||
flowClient, err := client.NewClient(m.ctx, m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("flow client connected to %s", m.flowConfig.URL)
|
||||
m.receiverClient = flowClient
|
||||
go m.receiveACKs()
|
||||
go m.startSender()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.disableFlow()
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Stop()
|
||||
}
|
||||
m.logger.Disable()
|
||||
if previous != nil && previous.Enabled {
|
||||
return m.receiverClient.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close cleans up all resources
|
||||
@@ -132,13 +95,6 @@ func (m *Manager) Close() {
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Close()
|
||||
}
|
||||
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("failed to close receiver client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Close()
|
||||
}
|
||||
|
||||
@@ -150,7 +106,6 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||
func (m *Manager) startSender() {
|
||||
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
@@ -158,62 +113,56 @@ func (m *Manager) startSender() {
|
||||
case <-ticker.C:
|
||||
events := m.logger.GetEvents()
|
||||
for _, event := range events {
|
||||
if err := m.send(event); err != nil {
|
||||
log.Errorf("failed to send flow event to server: %s", err)
|
||||
continue
|
||||
log.Infof("send flow event to server: %s", event.ID)
|
||||
err := m.send(event)
|
||||
if err != nil {
|
||||
log.Errorf("send flow event to server: %s", err)
|
||||
}
|
||||
log.Tracef("sent flow event: %s", event.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
log.Tracef("received flow event ack: %s", ack.EventId)
|
||||
m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
|
||||
func (m *Manager) receiveACKs() {
|
||||
if m.receiverClient == nil {
|
||||
return
|
||||
}
|
||||
err := m.receiverClient.Receive(m.ctx, func(ack *proto.FlowEventAck) error {
|
||||
log.Infof("receive flow event ack: %s", ack.EventId)
|
||||
m.logger.DeleteEvents([]string{ack.EventId})
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("failed to receive flow event ack: %s", err)
|
||||
if err != nil {
|
||||
log.Errorf("receive flow event ack: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) send(event *nftypes.Event) error {
|
||||
m.mux.Lock()
|
||||
client := m.receiverClient
|
||||
m.mux.Unlock()
|
||||
|
||||
if client == nil {
|
||||
if m.receiverClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return client.Send(toProtoEvent(m.publicKey, event))
|
||||
return m.receiverClient.Send(m.ctx, toProtoEvent(m.publicKey, event))
|
||||
}
|
||||
|
||||
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
|
||||
protoEvent := &proto.FlowEvent{
|
||||
EventId: event.ID[:],
|
||||
EventId: event.ID,
|
||||
Timestamp: timestamppb.New(event.Timestamp),
|
||||
PublicKey: publicKey,
|
||||
FlowFields: &proto.FlowFields{
|
||||
FlowId: event.FlowID[:],
|
||||
RuleId: event.RuleID,
|
||||
Type: proto.Type(event.Type),
|
||||
Direction: proto.Direction(event.Direction),
|
||||
Protocol: uint32(event.Protocol),
|
||||
SourceIp: event.SourceIP.AsSlice(),
|
||||
DestIp: event.DestIP.AsSlice(),
|
||||
RxPackets: event.RxPackets,
|
||||
TxPackets: event.TxPackets,
|
||||
RxBytes: event.RxBytes,
|
||||
TxBytes: event.TxBytes,
|
||||
SourceResourceId: event.SourceResourceID,
|
||||
DestResourceId: event.DestResourceID,
|
||||
FlowId: event.FlowID[:],
|
||||
RuleId: event.RuleID,
|
||||
Type: proto.Type(event.Type),
|
||||
Direction: proto.Direction(event.Direction),
|
||||
Protocol: uint32(event.Protocol),
|
||||
SourceIp: event.SourceIP.AsSlice(),
|
||||
DestIp: event.DestIP.AsSlice(),
|
||||
RxPackets: event.RxPackets,
|
||||
TxPackets: event.TxPackets,
|
||||
RxBytes: event.RxBytes,
|
||||
TxBytes: event.TxBytes,
|
||||
},
|
||||
}
|
||||
|
||||
if event.Protocol == nftypes.ICMP {
|
||||
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
|
||||
IcmpInfo: &proto.ICMPInfo{
|
||||
|
||||
@@ -3,22 +3,18 @@ package store
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
func NewMemoryStore() *Memory {
|
||||
return &Memory{
|
||||
events: make(map[uuid.UUID]*types.Event),
|
||||
events: make(map[string]*types.Event),
|
||||
}
|
||||
}
|
||||
|
||||
type Memory struct {
|
||||
mux sync.Mutex
|
||||
events map[uuid.UUID]*types.Event
|
||||
events map[string]*types.Event
|
||||
}
|
||||
|
||||
func (m *Memory) StoreEvent(event *types.Event) {
|
||||
@@ -30,7 +26,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
|
||||
func (m *Memory) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
maps.Clear(m.events)
|
||||
m.events = make(map[string]*types.Event)
|
||||
}
|
||||
|
||||
func (m *Memory) GetEvents() []*types.Event {
|
||||
@@ -43,7 +39,7 @@ func (m *Memory) GetEvents() []*types.Event {
|
||||
return events
|
||||
}
|
||||
|
||||
func (m *Memory) DeleteEvents(ids []uuid.UUID) {
|
||||
func (m *Memory) DeleteEvents(ids []string) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, id := range ids {
|
||||
|
||||
@@ -2,12 +2,11 @@ package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type Protocol uint8
|
||||
@@ -28,10 +27,8 @@ func (p Protocol) String() string {
|
||||
return "TCP"
|
||||
case 17:
|
||||
return "UDP"
|
||||
case 132:
|
||||
return "SCTP"
|
||||
default:
|
||||
return strconv.FormatUint(uint64(p), 10)
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,29 +61,27 @@ const (
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
ID uuid.UUID
|
||||
ID string
|
||||
Timestamp time.Time
|
||||
EventFields
|
||||
}
|
||||
|
||||
type EventFields struct {
|
||||
FlowID uuid.UUID
|
||||
Type Type
|
||||
RuleID []byte
|
||||
Direction Direction
|
||||
Protocol Protocol
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
SourceResourceID []byte
|
||||
DestResourceID []byte
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
RxPackets uint64
|
||||
TxPackets uint64
|
||||
RxBytes uint64
|
||||
TxBytes uint64
|
||||
FlowID uuid.UUID
|
||||
Type Type
|
||||
RuleID []byte
|
||||
Direction Direction
|
||||
Protocol Protocol
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
RxPackets uint64
|
||||
TxPackets uint64
|
||||
RxBytes uint64
|
||||
TxBytes uint64
|
||||
}
|
||||
|
||||
type FlowConfig struct {
|
||||
@@ -113,7 +108,7 @@ type FlowLogger interface {
|
||||
// GetEvents returns all stored events
|
||||
GetEvents() []*Event
|
||||
// DeleteEvents deletes events from the store
|
||||
DeleteEvents([]uuid.UUID)
|
||||
DeleteEvents([]string)
|
||||
// Close closes the logger
|
||||
Close()
|
||||
// Enable enables the flow logger receiver
|
||||
@@ -128,7 +123,7 @@ type Store interface {
|
||||
// GetEvents returns all stored events
|
||||
GetEvents() []*Event
|
||||
// DeleteEvents deletes events from the store
|
||||
DeleteEvents([]uuid.UUID)
|
||||
DeleteEvents([]string)
|
||||
// Close closes the store
|
||||
Close()
|
||||
}
|
||||
@@ -147,5 +142,5 @@ type ConnTracker interface {
|
||||
type IFaceMapper interface {
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() device.WGAddress
|
||||
}
|
||||
|
||||
@@ -1,27 +1,12 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
const (
|
||||
debounceTime = 2 * time.Second
|
||||
)
|
||||
|
||||
var checkChangeFn = checkChange
|
||||
var ErrStopped = errors.New("monitor has been stopped")
|
||||
|
||||
// NetworkMonitor watches for changes in network configuration.
|
||||
type NetworkMonitor struct {
|
||||
@@ -34,99 +19,3 @@ type NetworkMonitor struct {
|
||||
func New() *NetworkMonitor {
|
||||
return &NetworkMonitor{}
|
||||
}
|
||||
|
||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||
nw.mu.Lock()
|
||||
if nw.cancel != nil {
|
||||
nw.mu.Unlock()
|
||||
return errors.New("network monitor already started")
|
||||
}
|
||||
|
||||
ctx, nw.cancel = context.WithCancel(ctx)
|
||||
defer nw.cancel()
|
||||
nw.wg.Add(1)
|
||||
nw.mu.Unlock()
|
||||
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
return nil
|
||||
}
|
||||
|
||||
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
|
||||
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||
return fmt.Errorf("failed to get default next hops: %w", err)
|
||||
}
|
||||
|
||||
// recover in case sys ops panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
event := make(chan struct{}, 1)
|
||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||
|
||||
// debounce changes
|
||||
timer := time.NewTimer(0)
|
||||
timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-event:
|
||||
timer.Reset(debounceTime)
|
||||
case <-timer.C:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the network monitor.
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
nw.mu.Lock()
|
||||
defer nw.mu.Unlock()
|
||||
|
||||
if nw.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
nw.cancel()
|
||||
nw.wg.Wait()
|
||||
}
|
||||
|
||||
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
|
||||
for {
|
||||
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
|
||||
close(event)
|
||||
return
|
||||
}
|
||||
// prevent blocking
|
||||
select {
|
||||
case event <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
@@ -28,10 +28,18 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Debugf("Network monitor: closed routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return ErrStopped
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
@@ -68,11 +76,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
go callback()
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
go callback()
|
||||
}
|
||||
}
|
||||
}
|
||||
82
client/internal/networkmonitor/monitor_generic.go
Normal file
82
client/internal/networkmonitor/monitor_generic.go
Normal file
@@ -0,0 +1,82 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
||||
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
nw.mu.Lock()
|
||||
ctx, nw.cancel = context.WithCancel(ctx)
|
||||
nw.mu.Unlock()
|
||||
|
||||
nw.wg.Add(1)
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
return nil
|
||||
}
|
||||
|
||||
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
|
||||
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||
return fmt.Errorf("failed to get default next hops: %w", err)
|
||||
}
|
||||
|
||||
// recover in case sys ops panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
|
||||
return fmt.Errorf("check change: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the network monitor.
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
nw.mu.Lock()
|
||||
defer nw.mu.Unlock()
|
||||
|
||||
if nw.cancel != nil {
|
||||
nw.cancel()
|
||||
nw.wg.Wait()
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
||||
return errors.New("no interfaces available")
|
||||
}
|
||||
@@ -31,7 +31,8 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return ErrStopped
|
||||
|
||||
// handle route changes
|
||||
case route := <-routeChan:
|
||||
// default route and main table
|
||||
@@ -42,10 +43,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
// triggered on added/replaced routes
|
||||
case syscall.RTM_NEWROUTE:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_DELROUTE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -2,21 +2,10 @@
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
import "context"
|
||||
|
||||
type NetworkMonitor struct {
|
||||
}
|
||||
|
||||
// New creates a new network monitor.
|
||||
func New() *NetworkMonitor {
|
||||
return &NetworkMonitor{}
|
||||
}
|
||||
|
||||
func (nw *NetworkMonitor) Listen(_ context.Context) error {
|
||||
return fmt.Errorf("network monitor not supported on mobile platforms")
|
||||
func (nw *NetworkMonitor) Start(context.Context, func()) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
type MocMultiEvent struct {
|
||||
counter int
|
||||
}
|
||||
|
||||
func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
if m.counter == 0 {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
m.counter--
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNetworkMonitor_Close(t *testing.T) {
|
||||
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
nw := New()
|
||||
|
||||
var resErr error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
resErr = nw.Listen(context.Background())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second) // wait for the goroutine to start
|
||||
nw.Stop()
|
||||
|
||||
<-done
|
||||
if !errors.Is(resErr, context.Canceled) {
|
||||
t.Errorf("unexpected error: %v", resErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkMonitor_Event(t *testing.T) {
|
||||
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timeout.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
nw := New()
|
||||
defer nw.Stop()
|
||||
|
||||
var resErr error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
resErr = nw.Listen(context.Background())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-done
|
||||
if !errors.Is(resErr, nil) {
|
||||
t.Errorf("unexpected error: %v", nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkMonitor_MultiEvent(t *testing.T) {
|
||||
eventsRepeated := 3
|
||||
me := &MocMultiEvent{counter: eventsRepeated}
|
||||
checkChangeFn = me.checkChange
|
||||
|
||||
nw := New()
|
||||
defer nw.Stop()
|
||||
|
||||
done := make(chan struct{})
|
||||
started := time.Now()
|
||||
go func() {
|
||||
if resErr := nw.Listen(context.Background()); resErr != nil {
|
||||
t.Errorf("unexpected error: %v", resErr)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-done
|
||||
expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime
|
||||
if time.Since(started) < expectedResponseTime {
|
||||
t.Errorf("unexpected duration: %v", time.Since(started))
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
routeMonitor, err := systemops.NewRouteMonitor(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route monitor: %w", err)
|
||||
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return ErrStopped
|
||||
case route := <-routeMonitor.RouteUpdates():
|
||||
if route.Destination.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if routeChanged(route, nexthopv4, nexthopv6) {
|
||||
return nil
|
||||
if routeChanged(route, nexthopv4, nexthopv6, callback) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
|
||||
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
@@ -51,15 +51,18 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
|
||||
case systemops.RouteModified:
|
||||
// TODO: get routing table to figure out if our route is affected for modified routes
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
case systemops.RouteAdded:
|
||||
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
|
||||
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
}
|
||||
case systemops.RouteDeleted:
|
||||
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -442,8 +442,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
|
||||
if conn.isICEActive() {
|
||||
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
@@ -711,8 +711,8 @@ func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||
}
|
||||
|
||||
func (conn *Conn) isICEActive() bool {
|
||||
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
|
||||
func (conn *Conn) iceP2PIsActive() bool {
|
||||
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
@@ -17,5 +16,4 @@ type WGIface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetProxy() wgproxy.Proxy
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
type routeIDLookup struct {
|
||||
localMap sync.Map
|
||||
remoteMap sync.Map
|
||||
resolvedIPs sync.Map
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
|
||||
_, exists := r.localMap.LoadOrStore(route, resourceID)
|
||||
if exists {
|
||||
log.Tracef("resourceID %s already exists in local map", resourceID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
|
||||
r.localMap.Delete(route)
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
|
||||
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
|
||||
if exists {
|
||||
log.Tracef("resourceID %s already exists in remote map", resourceID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
|
||||
r.remoteMap.Delete(route)
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
|
||||
|
||||
r.resolvedIPs.Store(route.Addr(), resourceID)
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
|
||||
r.resolvedIPs.Delete(route.Addr())
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) Lookup(src, dst netip.Addr, direction nftypes.Direction) (srcResourceID, dstResourceID string) {
|
||||
|
||||
// check resolved ip's first
|
||||
resId, ok := r.resolvedIPs.Load(src)
|
||||
if ok {
|
||||
srcResourceID = resId.(string)
|
||||
} else {
|
||||
resId, ok := r.resolvedIPs.Load(dst)
|
||||
if ok {
|
||||
dstResourceID = resId.(string)
|
||||
}
|
||||
}
|
||||
|
||||
switch direction {
|
||||
case nftypes.Ingress:
|
||||
if srcResourceID == "" || dstResourceID == "" {
|
||||
r.localMap.Range(func(key, value interface{}) bool {
|
||||
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
|
||||
srcResourceID = value.(string)
|
||||
|
||||
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
|
||||
dstResourceID = value.(string)
|
||||
}
|
||||
|
||||
if srcResourceID != "" && dstResourceID != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
case nftypes.Egress:
|
||||
if srcResourceID == "" || dstResourceID == "" {
|
||||
r.remoteMap.Range(func(key, value interface{}) bool {
|
||||
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
|
||||
srcResourceID = value.(string)
|
||||
|
||||
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
|
||||
dstResourceID = value.(string)
|
||||
}
|
||||
|
||||
if srcResourceID != "" && dstResourceID != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return srcResourceID, dstResourceID
|
||||
}
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@@ -177,8 +176,6 @@ type Status struct {
|
||||
eventQueue *EventQueue
|
||||
|
||||
ingressGwMgr *ingressgw.Manager
|
||||
|
||||
routeIDLookup routeIDLookup
|
||||
}
|
||||
|
||||
// NewRecorder returns a new Status instance
|
||||
@@ -314,7 +311,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -326,14 +323,6 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string)
|
||||
peerState.AddRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
} else {
|
||||
|
||||
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
@@ -351,28 +340,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
peerState.DeleteRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
} else {
|
||||
d.routeIDLookup.RemoveRemoteRouteID(pref)
|
||||
}
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckRoutes checks if the source and destination addresses are within the same route
|
||||
// and returns the resource ID of the route that contains the addresses
|
||||
func (d *Status) CheckRoutes(src, dst netip.Addr, direction nftypes.Direction) (srcResId string, dstResId string) {
|
||||
if d == nil {
|
||||
return
|
||||
}
|
||||
|
||||
return d.routeIDLookup.Lookup(src, dst, direction)
|
||||
}
|
||||
|
||||
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
@@ -586,50 +558,6 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.notifyAddressChanged()
|
||||
}
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
}
|
||||
|
||||
if d.localPeer.Routes == nil {
|
||||
d.localPeer.Routes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
d.localPeer.Routes[route] = struct{}{}
|
||||
|
||||
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
// RemoveLocalPeerStateRoute removes a route from the local peer state
|
||||
func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
}
|
||||
|
||||
delete(d.localPeer.Routes, route)
|
||||
|
||||
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||
}
|
||||
|
||||
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
|
||||
func (d *Status) CleanLocalPeerStateRoutes() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer.Routes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
// CleanLocalPeerState cleans local peer status
|
||||
func (d *Status) CleanLocalPeerState() {
|
||||
d.mux.Lock()
|
||||
@@ -713,7 +641,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
|
||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -722,10 +650,6 @@ func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resol
|
||||
Prefixes: prefixes,
|
||||
ParentDomain: originalDomain,
|
||||
}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
@@ -736,10 +660,6 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
for k, v := range d.resolvedDomainsStates {
|
||||
if v.ParentDomain == domain {
|
||||
delete(d.resolvedDomainsStates, k)
|
||||
|
||||
for _, prefix := range v.Prefixes {
|
||||
d.routeIDLookup.RemoveResolvedIP(prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,12 +358,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
}
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
var routePrefixes []netip.Prefix
|
||||
for _, routes := range clientRoutes {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
@@ -371,8 +365,14 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
|
||||
}
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range routePrefixes {
|
||||
// default route is handled by route exclusion / ip rules
|
||||
// default route is
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
|
||||
c.connectEvent()
|
||||
}
|
||||
|
||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
|
||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add peer state route: %w", err)
|
||||
}
|
||||
|
||||
@@ -160,12 +160,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Net: "udp",
|
||||
@@ -321,7 +315,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
|
||||
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
|
||||
@@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
|
||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||
r.dynamicDomains[domain] = updatedPrefixes
|
||||
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes, r.route.GetResourceID())
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
|
||||
@@ -3,9 +3,9 @@ package iface
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type wgIfaceBase interface {
|
||||
@@ -13,7 +13,7 @@ type wgIfaceBase interface {
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
|
||||
@@ -103,7 +103,9 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
|
||||
m.statusRecorder.RemoveLocalPeerStateRoute(route.Network.String())
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
delete(state.Routes, route.Network.String())
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -129,12 +131,18 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||
|
||||
m.routes[route.ID] = route
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
state.Routes[routeStr] = struct{}{}
|
||||
|
||||
m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID())
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -156,7 +164,9 @@ func (m *serverRouter) cleanUp() {
|
||||
|
||||
}
|
||||
|
||||
m.statusRecorder.CleanLocalPeerStateRoutes()
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
state.Routes = nil
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||
|
||||
@@ -3,7 +3,7 @@ package server
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
|
||||
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
|
||||
stdErrorHandle = ^uintptr(11)
|
||||
)
|
||||
@@ -24,10 +25,13 @@ var (
|
||||
)
|
||||
|
||||
func handlePanicLog() error {
|
||||
// TODO: move this to a central location
|
||||
logDir := path.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||
logPath := path.Join(logDir, "netbird.err")
|
||||
logPath := os.Getenv(windowsPanicLogEnvVar)
|
||||
if logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
logDir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||
return fmt.Errorf("create panic log directory: %w", err)
|
||||
}
|
||||
@@ -35,11 +39,13 @@ func handlePanicLog() error {
|
||||
return fmt.Errorf("enforce permission on panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Open log file with append mode
|
||||
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Redirect stderr to the file
|
||||
if err = redirectStderr(f); err != nil {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after redirect error: %v", closeErr)
|
||||
@@ -53,6 +59,7 @@ func handlePanicLog() error {
|
||||
|
||||
// redirectStderr redirects stderr to the provided file
|
||||
func redirectStderr(f *os.File) error {
|
||||
// Get the current process's stderr handle
|
||||
if err := setStdHandle(f); err != nil {
|
||||
return fmt.Errorf("failed to set stderr handle: %w", err)
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ func (s *Server) Start() error {
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
|
||||
runningChan chan struct{},
|
||||
runningChan chan error,
|
||||
) {
|
||||
backOff := getConnectWithBackoff(ctx)
|
||||
retryStarted := false
|
||||
@@ -628,21 +628,20 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
|
||||
defer cancel()
|
||||
|
||||
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
|
||||
runningChan := make(chan error)
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-runningChan:
|
||||
return &proto.UpResponse{}, nil
|
||||
case err := <-runningChan:
|
||||
if err != nil {
|
||||
log.Debugf("waiting for engine to become ready failed: %s", err)
|
||||
} else {
|
||||
return &proto.UpResponse{}, nil
|
||||
}
|
||||
case <-callerCtx.Done():
|
||||
log.Debug("context done, stopping the wait for engine to become ready")
|
||||
return nil, callerCtx.Err()
|
||||
case <-timeoutCtx.Done():
|
||||
log.Debug("up is timed out, stopping the wait for engine to become ready")
|
||||
return nil, timeoutCtx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
@@ -42,21 +41,11 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
srcIP = engine.GetWgAddr()
|
||||
}
|
||||
|
||||
srcAddr, ok := netip.AddrFromSlice(srcIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
}
|
||||
|
||||
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||
if req.GetDestinationIp() == "self" {
|
||||
dstIP = engine.GetWgAddr()
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(dstIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
}
|
||||
|
||||
if srcIP == nil || dstIP == nil {
|
||||
return nil, fmt.Errorf("invalid IP address")
|
||||
}
|
||||
@@ -96,8 +85,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
}
|
||||
|
||||
builder := &uspfilter.PacketBuilder{
|
||||
SrcIP: srcAddr,
|
||||
DstIP: dstAddr,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: protocol,
|
||||
SrcPort: uint16(req.GetSourcePort()),
|
||||
DstPort: uint16(req.GetDestinationPort()),
|
||||
|
||||
@@ -4,10 +4,8 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -27,99 +25,95 @@ type GRPCClient struct {
|
||||
realClient proto.FlowServiceClient
|
||||
clientConn *grpc.ClientConn
|
||||
stream proto.FlowService_EventsClient
|
||||
streamMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
|
||||
var opts []grpc.DialOption
|
||||
func NewClient(ctx context.Context, addr, payload, signature string) (*GRPCClient, error) {
|
||||
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if strings.Contains(addr, "443") {
|
||||
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
})))
|
||||
} else {
|
||||
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
}))
|
||||
}
|
||||
|
||||
opts = append(opts,
|
||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
nbgrpc.WithCustomDialer(),
|
||||
grpc.WithIdleTimeout(interval*2),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
withAuthToken(payload, signature),
|
||||
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
|
||||
)
|
||||
|
||||
conn, err := grpc.NewClient(addr, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating new grpc client: %w", err)
|
||||
return nil, fmt.Errorf("dialing with context: %s", err)
|
||||
}
|
||||
|
||||
return &GRPCClient{
|
||||
client := &GRPCClient{
|
||||
realClient: proto.NewFlowServiceClient(conn),
|
||||
clientConn: conn,
|
||||
}, nil
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Close() error {
|
||||
c.streamMu.Lock()
|
||||
defer c.streamMu.Unlock()
|
||||
|
||||
c.stream = nil
|
||||
return c.clientConn.Close()
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
backOff := defaultBackoff(ctx, interval)
|
||||
func (c *GRPCClient) Receive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
backOff := defaultBackoff(ctx)
|
||||
operation := func() error {
|
||||
return c.establishStreamAndReceive(ctx, msgHandler)
|
||||
connState := c.clientConn.GetState()
|
||||
if connState == connectivity.Shutdown {
|
||||
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
|
||||
}
|
||||
|
||||
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.stream = stream
|
||||
|
||||
err = checkHeader(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.receive(stream, msgHandler)
|
||||
}
|
||||
|
||||
if err := backoff.Retry(operation, backOff); err != nil {
|
||||
return fmt.Errorf("receive failed permanently: %w", err)
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Errorf("exiting the flow receiver service connection retry loop due to the unrecoverable error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
if c.clientConn.GetState() == connectivity.Shutdown {
|
||||
return backoff.Permanent(errors.New("connection to flow receiver has been shut down"))
|
||||
}
|
||||
|
||||
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create event stream: %w", err)
|
||||
}
|
||||
|
||||
if err = checkHeader(stream); err != nil {
|
||||
return fmt.Errorf("check header: %w", err)
|
||||
}
|
||||
|
||||
c.streamMu.Lock()
|
||||
c.stream = stream
|
||||
c.streamMu.Unlock()
|
||||
|
||||
return c.receive(stream, msgHandler)
|
||||
}
|
||||
|
||||
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive from stream: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := msgHandler(msg); err != nil {
|
||||
return fmt.Errorf("handle message: %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,7 +122,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
|
||||
header, err := stream.Header()
|
||||
if err != nil {
|
||||
log.Errorf("waiting for flow receiver header: %s", err)
|
||||
return fmt.Errorf("wait for header: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(header) == 0 {
|
||||
@@ -138,29 +132,26 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
|
||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: 1.7,
|
||||
MaxInterval: interval / 2,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
||||
c.streamMu.Lock()
|
||||
stream := c.stream
|
||||
c.streamMu.Unlock()
|
||||
|
||||
if stream == nil {
|
||||
return errors.New("stream not initialized")
|
||||
func (c *GRPCClient) Send(ctx context.Context, event *proto.FlowEvent) error {
|
||||
if c.stream == nil {
|
||||
return fmt.Errorf("stream not initialized")
|
||||
}
|
||||
|
||||
if err := stream.Send(event); err != nil {
|
||||
return fmt.Errorf("send flow event: %w", err)
|
||||
err := c.stream.Send(event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending flow event: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -130,7 +130,7 @@ type FlowEvent struct {
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Unique client event identifier
|
||||
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||
EventId string `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||
// When the event occurred
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
// Public key of the sending peer
|
||||
@@ -170,11 +170,11 @@ func (*FlowEvent) Descriptor() ([]byte, []int) {
|
||||
return file_flow_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *FlowEvent) GetEventId() []byte {
|
||||
func (x *FlowEvent) GetEventId() string {
|
||||
if x != nil {
|
||||
return x.EventId
|
||||
}
|
||||
return nil
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *FlowEvent) GetTimestamp() *timestamppb.Timestamp {
|
||||
@@ -204,7 +204,7 @@ type FlowEventAck struct {
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Unique client event identifier that has been ack'ed
|
||||
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||
EventId string `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||
}
|
||||
|
||||
func (x *FlowEventAck) Reset() {
|
||||
@@ -239,11 +239,11 @@ func (*FlowEventAck) Descriptor() ([]byte, []int) {
|
||||
return file_flow_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *FlowEventAck) GetEventId() []byte {
|
||||
func (x *FlowEventAck) GetEventId() string {
|
||||
if x != nil {
|
||||
return x.EventId
|
||||
}
|
||||
return nil
|
||||
return ""
|
||||
}
|
||||
|
||||
type FlowFields struct {
|
||||
@@ -278,9 +278,6 @@ type FlowFields struct {
|
||||
// Number of bytes
|
||||
RxBytes uint64 `protobuf:"varint,12,opt,name=rx_bytes,json=rxBytes,proto3" json:"rx_bytes,omitempty"`
|
||||
TxBytes uint64 `protobuf:"varint,13,opt,name=tx_bytes,json=txBytes,proto3" json:"tx_bytes,omitempty"`
|
||||
// Resource ID
|
||||
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
|
||||
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
|
||||
}
|
||||
|
||||
func (x *FlowFields) Reset() {
|
||||
@@ -413,20 +410,6 @@ func (x *FlowFields) GetTxBytes() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *FlowFields) GetSourceResourceId() []byte {
|
||||
if x != nil {
|
||||
return x.SourceResourceId
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *FlowFields) GetDestResourceId() []byte {
|
||||
if x != nil {
|
||||
return x.DestResourceId
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isFlowFields_ConnectionInfo interface {
|
||||
isFlowFields_ConnectionInfo()
|
||||
}
|
||||
@@ -565,7 +548,7 @@ var file_flow_proto_rawDesc = []byte{
|
||||
0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x22, 0xb2, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
|
||||
0x01, 0x28, 0x09, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
|
||||
0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
|
||||
0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75,
|
||||
0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d,
|
||||
@@ -576,8 +559,8 @@ var file_flow_proto_rawDesc = []byte{
|
||||
0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x52, 0x0a, 0x66, 0x6c,
|
||||
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x22, 0x29, 0x0a, 0x0c, 0x46, 0x6c, 0x6f, 0x77,
|
||||
0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x49, 0x64, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c,
|
||||
0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x49, 0x64, 0x22, 0xc4, 0x03, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c,
|
||||
0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x04, 0x74,
|
||||
0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e, 0x66, 0x6c, 0x6f, 0x77,
|
||||
@@ -604,36 +587,31 @@ var file_flow_proto_rawDesc = []byte{
|
||||
0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20,
|
||||
0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08,
|
||||
0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07,
|
||||
0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x6f, 0x75, 0x72, 0x63,
|
||||
0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0e, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x65, 0x73, 0x6f, 0x75,
|
||||
0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65,
|
||||
0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0c, 0x52,
|
||||
0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x42,
|
||||
0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e,
|
||||
0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f,
|
||||
0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12,
|
||||
0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01,
|
||||
0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08,
|
||||
0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
|
||||
0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d,
|
||||
0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f,
|
||||
0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f,
|
||||
0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59,
|
||||
0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a,
|
||||
0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08,
|
||||
0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59,
|
||||
0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72,
|
||||
0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54,
|
||||
0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a,
|
||||
0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47,
|
||||
0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65,
|
||||
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12,
|
||||
0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74,
|
||||
0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65,
|
||||
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f,
|
||||
0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65,
|
||||
0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75,
|
||||
0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f,
|
||||
0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74,
|
||||
0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f,
|
||||
0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a,
|
||||
0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d,
|
||||
0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79,
|
||||
0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f,
|
||||
0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41,
|
||||
0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44,
|
||||
0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10,
|
||||
0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15,
|
||||
0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e,
|
||||
0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53,
|
||||
0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42,
|
||||
0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a,
|
||||
0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46,
|
||||
0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e,
|
||||
0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01,
|
||||
0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -13,7 +13,7 @@ service FlowService {
|
||||
|
||||
message FlowEvent {
|
||||
// Unique client event identifier
|
||||
bytes event_id = 1;
|
||||
string event_id = 1;
|
||||
|
||||
// When the event occurred
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
@@ -26,7 +26,7 @@ message FlowEvent {
|
||||
|
||||
message FlowEventAck {
|
||||
// Unique client event identifier that has been ack'ed
|
||||
bytes event_id = 1;
|
||||
string event_id = 1;
|
||||
}
|
||||
|
||||
message FlowFields {
|
||||
@@ -67,11 +67,6 @@ message FlowFields {
|
||||
// Number of bytes
|
||||
uint64 rx_bytes = 12;
|
||||
uint64 tx_bytes = 13;
|
||||
|
||||
// Resource ID
|
||||
bytes source_resource_id = 14;
|
||||
bytes dest_resource_id = 15;
|
||||
|
||||
}
|
||||
|
||||
// Flow event types
|
||||
|
||||
83
formatter/formatter.go
Normal file
83
formatter/formatter.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package formatter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TextFormatter formats logs into text with included source code's path
|
||||
type TextFormatter struct {
|
||||
timestampFormat string
|
||||
levelDesc []string
|
||||
}
|
||||
|
||||
// SyslogFormatter formats logs into text
|
||||
type SyslogFormatter struct {
|
||||
levelDesc []string
|
||||
}
|
||||
|
||||
var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
|
||||
|
||||
|
||||
// NewTextFormatter create new MyTextFormatter instance
|
||||
func NewTextFormatter() *TextFormatter {
|
||||
return &TextFormatter{
|
||||
levelDesc: validLevelDesc,
|
||||
timestampFormat: time.RFC3339, // or RFC3339
|
||||
}
|
||||
}
|
||||
|
||||
// NewSyslogFormatter create new MySyslogFormatter instance
|
||||
func NewSyslogFormatter() *SyslogFormatter {
|
||||
return &SyslogFormatter{
|
||||
levelDesc: validLevelDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Format renders a single log entry
|
||||
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
var fields string
|
||||
keys := make([]string, 0, len(entry.Data))
|
||||
for k, v := range entry.Data {
|
||||
if k == "source" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
|
||||
}
|
||||
|
||||
level := f.parseLevel(entry.Level)
|
||||
|
||||
return []byte(fmt.Sprintf("%s %s %s%s: %s\n", entry.Time.Format(f.timestampFormat), level, fields, entry.Data["source"], entry.Message)), nil
|
||||
}
|
||||
|
||||
func (f *TextFormatter) parseLevel(level logrus.Level) string {
|
||||
if len(f.levelDesc) < int(level) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return f.levelDesc[level]
|
||||
}
|
||||
|
||||
// Format renders a single log entry
|
||||
func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
var fields string
|
||||
keys := make([]string, 0, len(entry.Data))
|
||||
for k, v := range entry.Data {
|
||||
if k == "source" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package txt
|
||||
package formatter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -24,3 +24,20 @@ func TestLogTextFormat(t *testing.T) {
|
||||
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
|
||||
func TestLogSyslogFormat(t *testing.T) {
|
||||
|
||||
someEntry := &logrus.Entry{
|
||||
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
|
||||
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
|
||||
Level: 3,
|
||||
Message: "Some Message",
|
||||
}
|
||||
|
||||
formatter := NewSyslogFormatter()
|
||||
result, _ := formatter.Format(someEntry)
|
||||
|
||||
parsedString := string(result)
|
||||
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package hook
|
||||
package formatter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -41,8 +41,7 @@ func (hook ContextHook) Levels() []logrus.Level {
|
||||
// Fire extend with the source information the entry.Data
|
||||
func (hook ContextHook) Fire(entry *logrus.Entry) error {
|
||||
src := hook.parseSrc(entry.Caller.File)
|
||||
entry.Data[EntryKeySource] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
|
||||
additionalEntries(entry)
|
||||
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
|
||||
|
||||
if entry.Context == nil {
|
||||
return nil
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !loggoroutine
|
||||
|
||||
package hook
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
func additionalEntries(_ *log.Entry) {
|
||||
// This function is empty and is used to demonstrate the use of additional hooks.
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user