mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 07:04:17 -04:00
Compare commits
80 Commits
feature/fl
...
test/netwo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d2c774378 | ||
|
|
ab2e3fec72 | ||
|
|
47f88f7057 | ||
|
|
ee33a6ed7c | ||
|
|
da662cfd08 | ||
|
|
ed2ee1ee9d | ||
|
|
76d73548d6 | ||
|
|
11828a064a | ||
|
|
0c2a3dd937 | ||
|
|
cd9eff5331 | ||
|
|
47dcf8d68c | ||
|
|
80ceb80197 | ||
|
|
cc8f6bcaf3 | ||
|
|
636a0e2475 | ||
|
|
e66e329bf6 | ||
|
|
aaa23beeec | ||
|
|
6bef474e9e | ||
|
|
81040ff80a | ||
|
|
c73481aee4 | ||
|
|
92286b2541 | ||
|
|
d8bcf745b0 | ||
|
|
8430139d80 | ||
|
|
a2962b4ce0 | ||
|
|
16fffdb75b | ||
|
|
036cecbf46 | ||
|
|
3482852bb6 | ||
|
|
fd62665b1f | ||
|
|
fc1da94520 | ||
|
|
1ffe48f0d4 | ||
|
|
a3b8a21385 | ||
|
|
86492b88c4 | ||
|
|
d08a629f9e | ||
|
|
36da464413 | ||
|
|
268e3404d3 | ||
|
|
54d0591833 | ||
|
|
ae6b61301c | ||
|
|
86370a0e7b | ||
|
|
a444e551b3 | ||
|
|
53b9a2002f | ||
|
|
cb16d0f45f | ||
|
|
e8d8bd8f18 | ||
|
|
8b07f21c28 | ||
|
|
54be772ffd | ||
|
|
4b76d93cec | ||
|
|
3c3a454e61 | ||
|
|
5ff77b3595 | ||
|
|
b180edbe5c | ||
|
|
de3b5c78d7 | ||
|
|
0b42f40cf6 | ||
|
|
062d1ec76f | ||
|
|
0a042ac36d | ||
|
|
c111675dd8 | ||
|
|
e7f921d787 | ||
|
|
e9f11fb11b | ||
|
|
419ed275fa | ||
|
|
2d4fcaf186 | ||
|
|
acf172b52c | ||
|
|
8c81a823fa | ||
|
|
619c549547 | ||
|
|
60ffe0dc87 | ||
|
|
9a713a0987 | ||
|
|
c4945cd565 | ||
|
|
1e10c17ecb | ||
|
|
bcc5824980 | ||
|
|
af5796de1c | ||
|
|
9d604b7e66 | ||
|
|
96d5190436 | ||
|
|
d19c26df06 | ||
|
|
36e36414d9 | ||
|
|
7e69589e05 | ||
|
|
aa613ab79a | ||
|
|
6ead0ff95e | ||
|
|
0db65a8984 | ||
|
|
c138807e95 | ||
|
|
637c0c8949 | ||
|
|
c72e13d8e6 | ||
|
|
f6d7bccfa0 | ||
|
|
e3ed01cafb | ||
|
|
fa748a7ec2 | ||
|
|
82c12cc8ae |
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
||||
|
||||
`netbird version`
|
||||
|
||||
**NetBird status -dA output:**
|
||||
**Is any other VPN software installed?**
|
||||
|
||||
If applicable, add the `netbird status -dA' command output.
|
||||
If yes, which one?
|
||||
|
||||
**Do you face any (non-mobile) client issues?**
|
||||
**Debug output**
|
||||
|
||||
Please provide the file created by `netbird debug for 1m -AS`.
|
||||
We advise reviewing the anonymized files for any remaining PII.
|
||||
To help us resolve the problem, please attach the following debug output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
As well as the file created by
|
||||
|
||||
netbird debug for 1m -AS
|
||||
|
||||
|
||||
We advise reviewing the anonymized output for any remaining personal information.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
|
||||
**Additional context**
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -258,7 +258,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -325,8 +325,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -392,7 +392,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -461,7 +461,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -71,7 +71,7 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
@@ -150,7 +150,7 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
|
||||
@@ -53,9 +53,9 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird.png
|
||||
- src: client/ui/assets/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@@ -72,9 +72,9 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird.png
|
||||
- src: client/ui/assets/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
|
||||
@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan error, 1)
|
||||
run := make(chan struct{}, 1)
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
run <- err
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-run:
|
||||
if err != nil {
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
||||
}
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
}
|
||||
case err := <-clientErr:
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
case <-run:
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
@@ -10,17 +10,18 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
|
||||
@@ -75,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ type Manager struct {
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -166,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
|
||||
@@ -10,15 +10,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() iface.WGAddress
|
||||
AddressFunc func() wgaddr.Address
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -97,17 +97,17 @@ func TestIptablesManager(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
})
|
||||
}
|
||||
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
@@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -121,6 +121,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -128,7 +129,7 @@ func (r *router) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
|
||||
@@ -4,21 +4,20 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
if err := ipt.Reset(nil); err != nil {
|
||||
if err := ipt.Close(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -65,13 +65,13 @@ type Manager interface {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -80,7 +80,15 @@ type Manager interface {
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
@@ -94,8 +102,8 @@ type Manager interface {
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset(stateManager *statemanager.Manager) error
|
||||
// Close closes the firewall manager
|
||||
Close(stateManager *statemanager.Manager) error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var ipset *nftables.Set
|
||||
if ipsetName != "" {
|
||||
@@ -102,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -256,7 +256,6 @@ func (m *AclManager) addIOFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
comment string,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
@@ -338,7 +337,7 @@ func (m *AclManager) addIOFiltering(
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||
userData := []byte(ruleId)
|
||||
|
||||
chain := m.chainInputRules
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ const (
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Reset() without needing to store specific rules.
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -242,7 +243,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
|
||||
@@ -16,15 +16,15 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() iface.WGAddress
|
||||
AddressFunc func() wgaddr.Address
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
if err := manager.Close(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Reset(nil)
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
@@ -283,10 +283,11 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -228,6 +228,7 @@ func (r *router) createContainers() error {
|
||||
|
||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -236,7 +237,7 @@ func (r *router) AddRouteFiltering(
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
|
||||
@@ -3,21 +3,20 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
if err := nft.Reset(nil); err != nil {
|
||||
if err := nft.Close(nil); err != nil {
|
||||
return fmt.Errorf("reset nftables manager: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,39 +4,36 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
@@ -48,7 +45,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Reset(stateManager)
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -22,30 +23,30 @@ const (
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
|
||||
@@ -3,14 +3,14 @@ package common
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
// common.go
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
FlowId uuid.UUID
|
||||
Direction nftypes.Direction
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
lastSeen atomic.Int64
|
||||
PacketsTx atomic.Uint64
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// UpdateCounters safely updates the packet and byte counters
|
||||
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||
if direction == nftypes.Egress {
|
||||
b.PacketsTx.Add(1)
|
||||
b.BytesTx.Add(uint64(bytes))
|
||||
} else {
|
||||
b.PacketsRx.Add(1)
|
||||
b.BytesRx.Add(uint64(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
||||
return time.Since(lastSeen) > timeout
|
||||
}
|
||||
|
||||
// IPAddr is a fixed-size IP address to avoid allocations
|
||||
type IPAddr [16]byte
|
||||
|
||||
// MakeIPAddr creates an IPAddr from net.IP
|
||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
||||
// Optimization: check for v4 first as it's more common
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
copy(addr[12:], ip4)
|
||||
} else {
|
||||
copy(addr[:], ip.To16())
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// ConnKey uniquely identifies a connection
|
||||
type ConnKey struct {
|
||||
SrcIP IPAddr
|
||||
DstIP IPAddr
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// makeConnKey creates a connection key
|
||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||
return ConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateIPs checks if IPs match without allocation
|
||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
||||
if ip4 := pktIP.To4(); ip4 != nil {
|
||||
// Compare IPv4 addresses (last 4 bytes)
|
||||
for i := 0; i < 4; i++ {
|
||||
if connIP[12+i] != ip4[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Compare full IPv6 addresses
|
||||
ip6 := pktIP.To16()
|
||||
for i := 0; i < 16; i++ {
|
||||
if connIP[i] != ip6[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
||||
type PreallocatedIPs struct {
|
||||
sync.Pool
|
||||
}
|
||||
|
||||
// NewPreallocatedIPs creates a new IP pool
|
||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
||||
return &PreallocatedIPs{
|
||||
Pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
ip := make(net.IP, 16)
|
||||
return &ip
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an IP from the pool
|
||||
func (p *PreallocatedIPs) Get() net.IP {
|
||||
return *p.Pool.Get().(*net.IP)
|
||||
}
|
||||
|
||||
// Put returns an IP to the pool
|
||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
||||
p.Pool.Put(&ip)
|
||||
}
|
||||
|
||||
// copyIP copies an IP address efficiently
|
||||
func copyIP(dst, src net.IP) {
|
||||
if len(src) == 16 {
|
||||
copy(dst, src)
|
||||
} else {
|
||||
// Handle IPv4
|
||||
copy(dst[12:], src.To4())
|
||||
}
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
}
|
||||
|
||||
@@ -1,94 +1,67 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
func BenchmarkIPOperations(b *testing.B) {
|
||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = MakeIPAddr(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ValidateIPs", func(b *testing.B) {
|
||||
ip1 := net.ParseIP("192.168.1.1")
|
||||
ip2 := net.ParseIP("192.168.1.1")
|
||||
addr := MakeIPAddr(ip1)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ValidateIPs(addr, ip2)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IPPool", func(b *testing.B) {
|
||||
pool := NewPreallocatedIPs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ip := pool.Get()
|
||||
pool.Put(ip)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,18 +23,20 @@ const (
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
type ICMPConnKey struct {
|
||||
// Supports both IPv4 and IPv6
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
Sequence uint16 // ICMP sequence number
|
||||
ID uint16 // ICMP identifier
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
ID uint16
|
||||
}
|
||||
|
||||
func (i ICMPConnKey) String() string {
|
||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
}
|
||||
|
||||
// ICMPConnTrack represents an ICMP connection state
|
||||
type ICMPConnTrack struct {
|
||||
BaseConnTrack
|
||||
Sequence uint16
|
||||
ID uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
}
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
@@ -39,131 +45,201 @@ type ICMPTracker struct {
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
tickerCancel: cancel,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &ICMPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
},
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New ICMP connection %v", key)
|
||||
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||
key := ICMPConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
ID: id,
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
return key, true
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.ID == id &&
|
||||
conn.Sequence == seq
|
||||
return key, false
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine() {
|
||||
// TrackOutbound records an outbound ICMP connection
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
typ, code := typecode.Type(), typecode.Code()
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
|
||||
conn := &ICMPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := ICMPConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
ID: id,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.tickerCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
t.tickerCancel()
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
// makeICMPKey creates an ICMP connection key
|
||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||
return ICMPConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
ICMPCode: conn.ICMPCode,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeStart,
|
||||
RuleID: ruleID,
|
||||
Direction: direction,
|
||||
Protocol: nftypes.ICMP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
}
|
||||
if direction == nftypes.Ingress {
|
||||
fields.RxPackets = 1
|
||||
fields.RxBytes = uint64(size)
|
||||
} else {
|
||||
fields.TxPackets = 1
|
||||
fields.TxBytes = uint64(size)
|
||||
}
|
||||
t.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,12 +3,16 @@ package conntrack
|
||||
// TODO: Send RST packets for invalid/timed-out connections
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,6 +43,35 @@ const (
|
||||
// TCPState represents the state of a TCP connection
|
||||
type TCPState int
|
||||
|
||||
func (s TCPState) String() string {
|
||||
switch s {
|
||||
case TCPStateNew:
|
||||
return "New"
|
||||
case TCPStateSynSent:
|
||||
return "SYN Sent"
|
||||
case TCPStateSynReceived:
|
||||
return "SYN Received"
|
||||
case TCPStateEstablished:
|
||||
return "Established"
|
||||
case TCPStateFinWait1:
|
||||
return "FIN Wait 1"
|
||||
case TCPStateFinWait2:
|
||||
return "FIN Wait 2"
|
||||
case TCPStateClosing:
|
||||
return "Closing"
|
||||
case TCPStateTimeWait:
|
||||
return "Time Wait"
|
||||
case TCPStateCloseWait:
|
||||
return "Close Wait"
|
||||
case TCPStateLastAck:
|
||||
return "Last ACK"
|
||||
case TCPStateClosed:
|
||||
return "Closed"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
TCPStateNew TCPState = iota
|
||||
TCPStateSynSent
|
||||
@@ -53,19 +86,14 @@ const (
|
||||
TCPStateClosed
|
||||
)
|
||||
|
||||
// TCPConnKey uniquely identifies a TCP connection
|
||||
type TCPConnKey struct {
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
tombstone atomic.Bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -79,78 +107,126 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
}
|
||||
|
||||
// IsTombstone safely checks if the connection is marked for deletion
|
||||
func (t *TCPConnTrack) IsTombstone() bool {
|
||||
return t.tombstone.Load()
|
||||
}
|
||||
|
||||
// SetTombstone safely marks the connection for deletion
|
||||
func (t *TCPConnTrack) SetTombstone() {
|
||||
t.tombstone.Store(true)
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
done chan struct{}
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
ipPool *PreallocatedIPs
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTCPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||
conn.Unlock()
|
||||
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
return key, true
|
||||
}
|
||||
|
||||
return key, false
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
conn.established.Store(false)
|
||||
conn.tombstone.Store(false)
|
||||
|
||||
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
// Use preallocated IPs
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
State: TCPStateNew,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
// Lock individual connection for state update
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, true)
|
||||
conn.Unlock()
|
||||
conn.UpdateLastSeen()
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
@@ -159,22 +235,26 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle RST packets
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
conn.Lock()
|
||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
if conn.IsTombstone() {
|
||||
return true
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
conn.SetTombstone()
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
return false
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
return true
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, false)
|
||||
conn.UpdateLastSeen()
|
||||
t.updateState(key, conn, flags, false)
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
@@ -183,18 +263,17 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
conn.UpdateLastSeen()
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
return
|
||||
}
|
||||
state := conn.State
|
||||
defer func() {
|
||||
if state != conn.State {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||
}
|
||||
}()
|
||||
|
||||
switch conn.State {
|
||||
switch state {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
conn.State = TCPStateSynSent
|
||||
@@ -203,11 +282,11 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateSynReceived
|
||||
} else {
|
||||
// Simultaneous open
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
} else {
|
||||
// Simultaneous open
|
||||
conn.State = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
@@ -225,22 +304,32 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
conn.State = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
} else if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
// Simultaneous close - both sides sent FIN
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPRst != 0:
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
|
||||
t.logger.Trace("TCP connection %s completed", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
@@ -248,8 +337,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
@@ -260,17 +349,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
|
||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
// Send close event for gracefully closed connections
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||
}
|
||||
|
||||
case TCPStateTimeWait:
|
||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||
// This is handled by the cleanup routine
|
||||
|
||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -315,12 +399,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *TCPTracker) cleanupRoutine() {
|
||||
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -331,6 +417,12 @@ func (t *TCPTracker) cleanup() {
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.IsTombstone() {
|
||||
// Clean up tombstoned connections without sending an event
|
||||
delete(t.connections, key)
|
||||
continue
|
||||
}
|
||||
|
||||
var timeout time.Duration
|
||||
switch {
|
||||
case conn.State == TCPStateTimeWait:
|
||||
@@ -341,29 +433,26 @@ func (t *TCPTracker) cleanup() {
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
lastSeen := conn.GetLastSeen()
|
||||
if time.Since(lastSeen) > timeout {
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
// Return IPs to pool
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||
|
||||
// event already handled by state change
|
||||
if conn.State != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
t.tickerCancel()
|
||||
|
||||
// Clean up all remaining IPs
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
@@ -381,3 +470,21 @@ func isValidFlagCombination(flags uint8) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.TCP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||
})
|
||||
}
|
||||
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// Send initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Receive SYN-ACK
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
// Send ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Test data transfer
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||
require.True(t, valid, "Data should be allowed after handshake")
|
||||
},
|
||||
},
|
||||
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid, "ACK for FIN should be allowed")
|
||||
|
||||
// Receive FIN from other side
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "FIN should be allowed")
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Both sides send FIN+ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||
|
||||
// Both sides send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid, "Final ACKs should be allowed")
|
||||
},
|
||||
},
|
||||
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST in established",
|
||||
setupState: func() {
|
||||
// Establish connection first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
},
|
||||
wantValid: true,
|
||||
desc: "Should accept RST for established connection",
|
||||
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST without connection",
|
||||
setupState: func() {},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
},
|
||||
wantValid: false,
|
||||
desc: "Should reject RST without connection",
|
||||
@@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
|
||||
tt.sendRST()
|
||||
|
||||
// Verify connection state is as expected
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
@@ -220,63 +225,63 @@ func TestRSTHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
@@ -287,14 +292,14 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,6 +22,8 @@ const (
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
type UDPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
}
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
@@ -26,89 +32,125 @@ type UDPTracker struct {
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &UDPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
tickerCancel: cancel,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New UDP connection: %v", conn)
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
return key, false
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.DestPort == srcPort &&
|
||||
conn.SourcePort == dstPort
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine() {
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -120,44 +162,58 @@ func (t *UDPTracker) cleanup() {
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *UDPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
t.tickerCancel()
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
return conn, true
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// Timeout returns the configured timeout duration for the tracker
|
||||
func (t *UDPTracker) Timeout() time.Duration {
|
||||
return t.timeout
|
||||
}
|
||||
|
||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.UDP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -29,54 +30,59 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout, logger)
|
||||
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
assert.NotNil(t, tracker.cleanupTicker)
|
||||
assert.NotNil(t, tracker.done)
|
||||
assert.NotNil(t, tracker.tickerCancel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
|
||||
// Verify connection was tracked
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := tracker.connections[key]
|
||||
require.True(t, exists)
|
||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||
assert.Equal(t, srcPort, conn.SourcePort)
|
||||
assert.Equal(t, dstPort, conn.DestPort)
|
||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1*time.Second, logger)
|
||||
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
// Track outbound connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
sleep time.Duration
|
||||
@@ -93,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid source IP",
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
@@ -103,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
{
|
||||
name: "invalid destination IP",
|
||||
srcIP: dstIP,
|
||||
dstIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
@@ -143,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
if tt.sleep > 0 {
|
||||
time.Sleep(tt.sleep)
|
||||
}
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
@@ -154,42 +160,45 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
timeout := 50 * time.Millisecond
|
||||
cleanupInterval := 25 * time.Millisecond
|
||||
|
||||
ctx, tickerCancel := context.WithCancel(context.Background())
|
||||
defer tickerCancel()
|
||||
|
||||
// Create tracker with custom cleanup interval
|
||||
tracker := &UDPTracker{
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
tickerCancel: tickerCancel,
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go tracker.cleanupRoutine()
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
|
||||
// Add some connections
|
||||
connections := []struct {
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.2"),
|
||||
dstIP: net.ParseIP("192.168.1.3"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
},
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.5"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||
srcPort: 12346,
|
||||
dstPort: 53,
|
||||
},
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
@@ -211,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,6 +30,7 @@ const (
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
@@ -38,7 +40,7 @@ type Forwarder struct {
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger),
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
|
||||
@@ -3,14 +3,30 @@ package forwarder
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
default:
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return true
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return true
|
||||
return
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
return true
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||
return true
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
// TODO: get packets/bytes
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,24 +5,38 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
}
|
||||
}()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep)
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -5,10 +5,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
@@ -16,6 +18,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,15 +31,17 @@ type udpPacketConn struct {
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
ep tcpip.Endpoint
|
||||
flowID uuid.UUID
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
logger *nblog.Logger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
@@ -44,13 +49,14 @@ type idleConn struct {
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, mtu)
|
||||
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
|
||||
for id, conn := range f.conns {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
conn.ep.Close()
|
||||
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
|
||||
for _, idle := range idleConns {
|
||||
idle.conn.cancel()
|
||||
if err := idle.conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||
}
|
||||
if err := idle.conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||
}
|
||||
|
||||
idle.conn.ep.Close()
|
||||
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||
return
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
}
|
||||
}()
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if epErr != nil {
|
||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
outConn: outConn,
|
||||
cancel: connCancel,
|
||||
ep: ep,
|
||||
flowID: flowID,
|
||||
}
|
||||
pConn.updateLastSeen()
|
||||
|
||||
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.udpForwarder.Unlock()
|
||||
pConn.cancel()
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
}
|
||||
|
||||
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
defer func() {
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||
}
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
c.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
@@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
if ip.Is4() {
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -2,90 +2,91 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
setupAddr wgaddr.Address
|
||||
testIP netip.Addr
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: iface.WGAddress{
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
@@ -95,7 +96,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return tt.setupAddr
|
||||
},
|
||||
}
|
||||
@@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||
package log
|
||||
|
||||
import (
|
||||
@@ -13,13 +13,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||
maxMessageSize = 1024 * 2 // 2KB per message
|
||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||
maxBatchSize = 1024 * 16
|
||||
maxMessageSize = 1024 * 2
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
logChannelSize = 1000
|
||||
)
|
||||
|
||||
// Level represents log severity
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
|
||||
LevelTrace: "TRAC",
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
buffer *ringBuffer
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Reusable buffer pool for formatting messages
|
||||
bufPool sync.Pool
|
||||
type logMessage struct {
|
||||
level Level
|
||||
format string
|
||||
args []any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
msgChannel chan logMessage
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
buffer: newRingBuffer(bufferSize),
|
||||
shutdown: make(chan struct{}),
|
||||
output: logrusLogger.Out,
|
||||
msgChannel: make(chan logMessage, logChannelSize),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// Pre-allocate buffer for message formatting
|
||||
New: func() any {
|
||||
b := make([]byte, 0, maxMessageSize)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logrusLevel := logrusLogger.GetLevel()
|
||||
l.level.Store(uint32(logrusLevel))
|
||||
level := levelStrings[Level(logrusLevel)]
|
||||
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
// SetLevel sets the logging level
|
||||
func (l *Logger) SetLevel(level Level) {
|
||||
l.level.Store(uint32(level))
|
||||
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||
*buf = (*buf)[:0]
|
||||
|
||||
// Timestamp
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Level
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Message
|
||||
if len(args) > 0 {
|
||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||
} else {
|
||||
*buf = append(*buf, format...)
|
||||
func (l *Logger) log(level Level, format string, args ...any) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||
default:
|
||||
}
|
||||
|
||||
*buf = append(*buf, '\n')
|
||||
}
|
||||
|
||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
l.formatMessage(bufp, level, format, args...)
|
||||
|
||||
if len(*bufp) > maxMessageSize {
|
||||
*bufp = (*bufp)[:maxMessageSize]
|
||||
}
|
||||
_, _ = l.buffer.Write(*bufp)
|
||||
|
||||
l.bufPool.Put(bufp)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
// Error logs a message at error level
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
l.log(LevelError, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||
// Warn logs a message at warning level
|
||||
func (l *Logger) Warn(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
l.log(LevelWarn, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
// Info logs a message at info level
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelInfo) {
|
||||
l.log(LevelInfo, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
// Debug logs a message at debug level
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
l.log(LevelDebug, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||
// Trace logs a message at trace level
|
||||
func (l *Logger) Trace(format string, args ...any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
l.log(LevelTrace, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// worker periodically flushes the buffer
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||
*buf = (*buf)[:0]
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
var msg string
|
||||
if len(args) > 0 {
|
||||
msg = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
msg = format
|
||||
}
|
||||
*buf = append(*buf, msg...)
|
||||
*buf = append(*buf, '\n')
|
||||
|
||||
if len(*buf) > maxMessageSize {
|
||||
*buf = (*buf)[:maxMessageSize]
|
||||
}
|
||||
}
|
||||
|
||||
// processMessage handles a single log message and adds it to the buffer
|
||||
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
defer l.bufPool.Put(bufp)
|
||||
|
||||
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||
|
||||
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||
_, _ = l.output.Write(*buffer)
|
||||
*buffer = (*buffer)[:0]
|
||||
}
|
||||
*buffer = append(*buffer, *bufp...)
|
||||
}
|
||||
|
||||
// flushBuffer writes the accumulated buffer to output
|
||||
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||
if len(*buffer) > 0 {
|
||||
_, _ = l.output.Write(*buffer)
|
||||
*buffer = (*buffer)[:0]
|
||||
}
|
||||
}
|
||||
|
||||
// processBatch processes as many messages as possible without blocking
|
||||
func (l *Logger) processBatch(buffer *[]byte) {
|
||||
for len(*buffer) < maxBatchSize {
|
||||
select {
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, buffer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, buffer)
|
||||
case <-ctx.Done():
|
||||
l.flushBuffer(buffer)
|
||||
return
|
||||
}
|
||||
|
||||
if len(l.msgChannel) == 0 {
|
||||
l.flushBuffer(buffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// worker is the main goroutine that processes log messages
|
||||
func (l *Logger) worker() {
|
||||
defer l.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(defaultFlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
buf := make([]byte, 0, maxBatchSize)
|
||||
buffer := make([]byte, 0, maxBatchSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.shutdown:
|
||||
l.handleShutdown(&buffer)
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Read accumulated messages
|
||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write batch
|
||||
_, _ = l.output.Write(buf[:n])
|
||||
l.flushBuffer(&buffer)
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, &buffer)
|
||||
l.processBatch(&buffer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package log_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
type discard struct{}
|
||||
|
||||
func (d *discard) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func BenchmarkLogger(b *testing.B) {
|
||||
simpleMessage := "Connection established"
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4 // TCPStateEstablished
|
||||
|
||||
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||
protocol := "TCP"
|
||||
direction := "outbound"
|
||||
flags := uint16(0x18) // ACK + PSH
|
||||
sequence := uint32(123456789)
|
||||
acknowledged := uint32(987654321)
|
||||
payloadSize := 1460
|
||||
fragmented := false
|
||||
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||
|
||||
b.Run("SimpleMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(simpleMessage)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ComplexMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||
func BenchmarkLoggerParallel(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||
func BenchmarkLoggerBurst(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for j := 0; j < 100; j++ {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createTestLogger() *log.Logger {
|
||||
logrusLogger := logrus.New()
|
||||
logrusLogger.SetOutput(&discard{})
|
||||
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||
return log.NewFromLogrus(logrusLogger)
|
||||
}
|
||||
|
||||
func cleanupLogger(logger *log.Logger) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = logger.Stop(ctx)
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package log
|
||||
|
||||
import "sync"
|
||||
|
||||
// ringBuffer is a simple ring buffer implementation
|
||||
type ringBuffer struct {
|
||||
buf []byte
|
||||
size int
|
||||
r, w int64 // Read and write positions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newRingBuffer(size int) *ringBuffer {
|
||||
return &ringBuffer{
|
||||
buf: make([]byte, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(p) > r.size {
|
||||
p = p[:r.size]
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
|
||||
// Write data, handling wrap-around
|
||||
pos := int(r.w % int64(r.size))
|
||||
writeLen := min(len(p), r.size-pos)
|
||||
copy(r.buf[pos:], p[:writeLen])
|
||||
|
||||
// If we have more data and need to wrap around
|
||||
if writeLen < len(p) {
|
||||
copy(r.buf, p[writeLen:])
|
||||
}
|
||||
|
||||
// Update write position
|
||||
r.w += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.w == r.r {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Calculate available data accounting for wraparound
|
||||
available := int(r.w - r.r)
|
||||
if available < 0 {
|
||||
available += r.size
|
||||
}
|
||||
available = min(available, r.size)
|
||||
|
||||
// Limit read to buffer size
|
||||
toRead := min(available, len(p))
|
||||
if toRead == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read data, handling wrap-around
|
||||
pos := int(r.r % int64(r.size))
|
||||
readLen := min(toRead, r.size-pos)
|
||||
n = copy(p, r.buf[pos:pos+readLen])
|
||||
|
||||
// If we need more data and need to wrap around
|
||||
if readLen < toRead {
|
||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||
}
|
||||
|
||||
// Update read position
|
||||
r.r += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -12,14 +11,14 @@ import (
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
mgmtId []byte
|
||||
ip netip.Addr
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
comment string
|
||||
|
||||
udpHook func([]byte) bool
|
||||
}
|
||||
@@ -31,6 +30,7 @@ func (r *PeerRule) ID() string {
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
|
||||
@@ -2,7 +2,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
SourceIP netip.Addr
|
||||
DestinationIP netip.Addr
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
}
|
||||
|
||||
if !m.handleRouting(trace) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter {
|
||||
if m.nativeRouter.Load() {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
msg = m.buildConntrackStateMessage(d)
|
||||
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
|
||||
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
strRuleId := "<no id>"
|
||||
if ruleId != nil {
|
||||
strRuleId = string(ruleId)
|
||||
}
|
||||
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||
if blocked {
|
||||
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||
trace.AddResult(StagePeerACL, msg, false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
msg := "Allowed by peer ACL rules"
|
||||
if blocked {
|
||||
msg = "Blocked by peer ACL rules"
|
||||
}
|
||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||
trace.AddResult(StagePeerACL, msg, true)
|
||||
|
||||
// Handle netstack mode
|
||||
if m.netstack {
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||
switch {
|
||||
case !m.localForwarding:
|
||||
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||
case m.forwarder.Load() != nil:
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
default:
|
||||
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||
// In normal mode, packets are allowed through for local delivery
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled {
|
||||
if !m.routingEnabled.Load() {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
proto := getProtocolFromPacket(d)
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
msg := "Allowed by route ACLs"
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
strId = "<no id>"
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||
if !allowed {
|
||||
msg = "Blocked by route ACLs"
|
||||
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder != nil {
|
||||
if allowed && m.forwarder.Load() != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData)
|
||||
dropped := m.processOutgoingHooks(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
|
||||
440
client/firewall/uspfilter/tracer_test.go
Normal file
440
client/firewall/uspfilter/tracer_test.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||
t.Logf("Trace results: %v", trace.Results)
|
||||
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||
for _, result := range trace.Results {
|
||||
actualStages = append(actualStages, result.Stage)
|
||||
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||
}
|
||||
|
||||
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||
lastResult := trace.Results[len(trace.Results)-1]
|
||||
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||
}
|
||||
|
||||
func TestTracePacket(t *testing.T) {
|
||||
setupTracerTest := func(statefulMode bool) *Manager {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
m.stateful = false
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||
builder := &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: protocol,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Direction: direction,
|
||||
}
|
||||
|
||||
if protocol == "tcp" {
|
||||
builder.TCPState = &TCPState{SYN: true}
|
||||
}
|
||||
|
||||
return builder
|
||||
}
|
||||
|
||||
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||
return &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: "icmp",
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
Direction: direction,
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(*Manager)
|
||||
packetBuilder func() *PacketBuilder
|
||||
expectedStages []PacketStage
|
||||
expectedAllow bool
|
||||
}{
|
||||
{
|
||||
name: "LocalTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = true
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithoutForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_NativeRouter",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_RoutingDisabled",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(false)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "ConnectionTracking_Hit",
|
||||
setup: func(m *Manager) {
|
||||
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||
return pb
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "OutboundTraffic",
|
||||
setup: func(m *Manager) {
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPEchoRequest",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPDestinationUnreachable",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithoutHook",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "StatefulDisabled_NoTracking",
|
||||
setup: func(m *Manager) {
|
||||
m.stateful = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
m := setupTracerTest(true)
|
||||
|
||||
tc.setup(m)
|
||||
|
||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||
"100.10.0.100 should be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||
"172.17.0.2 should not be recognized as a local IP")
|
||||
|
||||
pb := tc.packetBuilder()
|
||||
|
||||
trace, err := m.TracePacketFromBuilder(pb)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifyTraceStages(t, trace, tc.expectedStages)
|
||||
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -65,9 +67,9 @@ func (r RouteRules) Sort() {
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[string]RuleSet
|
||||
outgoingRules map[netip.Addr]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
@@ -79,9 +81,9 @@ type Manager struct {
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled bool
|
||||
routingEnabled atomic.Bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
nativeRouter bool
|
||||
nativeRouter atomic.Bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
@@ -94,8 +96,9 @@ type Manager struct {
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder *forwarder.Forwarder
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -112,16 +115,16 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,7 +151,7 @@ func parseCreateEnv() (bool, bool) {
|
||||
return disableConntrack, enableLocalForwarding
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
@@ -166,17 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||
incomingRules: make(map[netip.Addr]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
@@ -185,9 +189,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
@@ -208,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder == nil {
|
||||
if m.forwarder.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
@@ -218,6 +222,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
if _, err := m.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
wgPrefix,
|
||||
firewall.ProtocolALL,
|
||||
@@ -251,20 +256,20 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case m.disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
|
||||
log.Info("server routes are disabled")
|
||||
|
||||
case forceUserspaceRouter:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
@@ -272,19 +277,19 @@ func (m *Manager) determineRouting() error {
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
|
||||
log.Info("native routing is enabled")
|
||||
|
||||
default:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
if m.routingEnabled && !m.nativeRouter {
|
||||
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||
return m.initForwarder()
|
||||
}
|
||||
|
||||
@@ -293,24 +298,24 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder != nil {
|
||||
if m.forwarder.Load() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := m.wgIface.GetWGDevice()
|
||||
if intf == nil {
|
||||
m.routingEnabled = false
|
||||
m.routingEnabled.Store(false)
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||
if err != nil {
|
||||
m.routingEnabled = false
|
||||
m.routingEnabled.Store(false)
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
}
|
||||
|
||||
m.forwarder = forwarder
|
||||
m.forwarder.Store(forwarder)
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
@@ -326,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
|
||||
@@ -337,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
return nil
|
||||
@@ -348,25 +353,31 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
// TODO: fix in upper layers
|
||||
i, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
|
||||
i = i.Unmap()
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
mgmtId: id,
|
||||
ip: i,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
if i.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
r.ip = ipNormalized
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
@@ -391,15 +402,16 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
@@ -407,16 +419,15 @@ func (m *Manager) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
@@ -425,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
|
||||
action: action,
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules.Sort()
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@@ -461,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
delete(m.incomingRules[r.ip], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -497,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
return m.processOutgoingHooks(packetData)
|
||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||
return m.processOutgoingHooks(packetData, size)
|
||||
}
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData)
|
||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||
return m.dropFilter(packetData, size)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
@@ -511,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -527,52 +537,37 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return false
|
||||
}
|
||||
|
||||
// Track all protocols if stateful mode is enabled
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Process UDP hooks even if stateful mode is disabled
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
if m.stateful {
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
return d.ip4.SrcIP, d.ip4.DstIP
|
||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
return src, dst
|
||||
case layers.LayerTypeIPv6:
|
||||
return d.ip6.SrcIP, d.ip6.DstIP
|
||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||
return src, dst
|
||||
default:
|
||||
return nil, nil
|
||||
return netip.Addr{}, netip.Addr{}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
flags,
|
||||
)
|
||||
}
|
||||
|
||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
var flags uint8
|
||||
if tcp.SYN {
|
||||
@@ -596,45 +591,70 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
m.udpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
)
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||
}
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||
m.icmpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
)
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -643,19 +663,19 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
@@ -663,10 +683,29 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
if blocked {
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
RxPackets: 1,
|
||||
RxBytes: uint64(size),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -675,6 +714,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
|
||||
return m.handleNetstackLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
// track inbound packets to get the correct direction and session id for flows
|
||||
m.trackInbound(d, srcIP, dstIP, ruleID, size)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -684,12 +726,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.forwarder == nil {
|
||||
if m.forwarder.Load() == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
return true
|
||||
}
|
||||
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
|
||||
@@ -699,30 +741,43 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled {
|
||||
if !m.routingEnabled.Load() {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter {
|
||||
if m.nativeRouter.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
proto := getProtocolFromPacket(d)
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||
srcIP, srcPort, dstIP, dstPort, proto)
|
||||
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
@@ -730,16 +785,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
||||
return true
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return firewall.ProtocolTCP
|
||||
return firewall.ProtocolTCP, nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return firewall.ProtocolUDP
|
||||
return firewall.ProtocolUDP, nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return firewall.ProtocolICMP
|
||||
return firewall.ProtocolICMP, nftypes.ICMP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||
}
|
||||
}
|
||||
|
||||
@@ -767,7 +822,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return m.tcpTracker.IsValidInbound(
|
||||
@@ -776,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
getTCPFlags(&d.tcp),
|
||||
size,
|
||||
)
|
||||
|
||||
case layers.LayerTypeUDP:
|
||||
@@ -784,6 +840,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
size,
|
||||
)
|
||||
|
||||
case layers.LayerTypeICMPv4:
|
||||
@@ -791,8 +848,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
d.icmp4.TypeCode.Type(),
|
||||
size,
|
||||
)
|
||||
|
||||
// TODO: ICMPv6
|
||||
@@ -812,25 +869,27 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
if m.isSpecialICMP(d) {
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
return filter
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
||||
return filter
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
||||
return filter
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
// Default policy: DROP ALL
|
||||
return true
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
@@ -850,15 +909,15 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
@@ -868,39 +927,36 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.udpHook(packetData), true
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop, true
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
return rule.action == firewall.ActionAccept
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
@@ -940,36 +996,32 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
@@ -1017,20 +1069,21 @@ func (m *Manager) DisableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.forwarder == nil {
|
||||
fwder := m.forwarder.Load()
|
||||
if fwder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.forwarder.Stop()
|
||||
m.forwarder = nil
|
||||
fwder.Stop()
|
||||
m.forwarder.Store(nil)
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
|
||||
@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.ActionAccept, "", "allow all")
|
||||
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.ActionDrop, "", "default drop")
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -158,9 +169,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -203,9 +214,9 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.processOutgoingHooks(testOut)
|
||||
manager.processOutgoingHooks(testOut, 0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn)
|
||||
manager.dropFilter(testIn, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -251,9 +262,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -450,9 +461,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Setup scenario
|
||||
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack)
|
||||
manager.dropFilter(synack, 0)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -577,9 +588,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -668,9 +676,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -787,9 +792,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -875,9 +877,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
nil,
|
||||
r.port,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
@@ -26,20 +26,20 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false)
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = wgNet
|
||||
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
tc.ruleProto,
|
||||
tc.ruleSrcPort,
|
||||
tc.ruleDstPort,
|
||||
tc.ruleAction,
|
||||
"",
|
||||
tc.name,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||
})
|
||||
}
|
||||
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
@@ -302,15 +302,15 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false)
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NoError(tb, err)
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled)
|
||||
require.False(tb, manager.nativeRouter)
|
||||
require.True(tb, manager.routingEnabled.Load())
|
||||
require.False(tb, manager.nativeRouter.Load())
|
||||
|
||||
tb.Cleanup(func() {
|
||||
require.NoError(tb, manager.Reset(nil))
|
||||
require.NoError(tb, manager.Close(nil))
|
||||
})
|
||||
|
||||
return manager
|
||||
@@ -803,6 +803,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
tc.rule.dest,
|
||||
tc.rule.proto,
|
||||
@@ -817,12 +818,12 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
})
|
||||
}
|
||||
@@ -985,6 +986,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
var rules []fw.Rule
|
||||
for _, r := range tc.rules {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
@@ -1004,10 +1006,10 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
})
|
||||
|
||||
for i, p := range tc.packets {
|
||||
srcIP := net.ParseIP(p.srcIP)
|
||||
dstIP := net.ParseIP(p.dstIP)
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
|
||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -16,15 +18,17 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
AddressFunc func() wgaddr.Address
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
@@ -50,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
return i.SetFilterFunc(iface)
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address() iface.WGAddress {
|
||||
func (i *IFaceMock) Address() wgaddr.Address {
|
||||
if i.AddressFunc == nil {
|
||||
return iface.WGAddress{}
|
||||
return wgaddr.Address{}
|
||||
}
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -92,9 +96,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -116,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
ip := netip.MustParseAddr("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: net.IPv4(10, 168, 0, 1),
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: net.IPv6loopback,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
@@ -187,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
@@ -206,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
@@ -236,7 +238,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -246,15 +248,14 @@ func TestManagerReset(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Reset(nil)
|
||||
err = m.Close(nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
@@ -268,8 +269,8 @@ func TestManagerReset(t *testing.T) {
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
@@ -279,7 +280,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false)
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -292,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -328,12 +328,12 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes()) {
|
||||
if m.dropFilter(buf.Bytes(), 0) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
|
||||
if err = m.Reset(nil); err != nil {
|
||||
if err = m.Close(nil); err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -347,17 +347,17 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false)
|
||||
manager, err := Create(iface, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -401,9 +401,9 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.decoders = sync.Pool{
|
||||
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
net.ParseIP("100.10.0.100"),
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
hookCalled = true
|
||||
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.processOutgoingHooks(buf.Bytes())
|
||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.processOutgoingHooks(buf.Bytes())
|
||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
@@ -479,12 +479,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false)
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
if err := manager.Close(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false)
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -530,12 +530,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set up packet parameters
|
||||
srcIP := net.ParseIP("100.10.0.1")
|
||||
dstIP := net.ParseIP("100.10.0.100")
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||
srcPort := uint16(51334)
|
||||
dstPort := uint16(53)
|
||||
|
||||
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
outboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcIP: srcIP.AsSlice(),
|
||||
DstIP: dstIP.AsSlice(),
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
outboundUDP := &layers.UDP{
|
||||
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
||||
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||
|
||||
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
inboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: dstIP, // Original destination is now source
|
||||
DstIP: srcIP, // Original source is now destination
|
||||
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
inboundUDP := &layers.UDP{
|
||||
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes())
|
||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
@@ -14,6 +13,8 @@ import (
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -52,9 +53,10 @@ type ICEBind struct {
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// force IPv4
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeAddr] = conn
|
||||
b.endpoints[fakeIP] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
|
||||
return fakeUDPAddr, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
log.Warnf("failed to convert IP to netip.Addr")
|
||||
return
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
b.endpointsMu.Lock()
|
||||
defer b.endpointsMu.Unlock()
|
||||
delete(b.endpoints, fakeAddr)
|
||||
|
||||
delete(b.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
newAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
||||
Port: peerAddress.Port,
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
FilterFn FilterFn
|
||||
WGAddress wgaddr.Address
|
||||
}
|
||||
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
filterFn: params.FilterFn,
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
// embed UDPMux
|
||||
@@ -118,6 +122,7 @@ type udpConn struct {
|
||||
filterFn FilterFn
|
||||
// TODO: reset cache on route changes
|
||||
addrCache sync.Map
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a.AsSlice()) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
|
||||
@@ -9,13 +9,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
|
||||
@@ -13,11 +13,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||
type WGTunDevice struct {
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -31,7 +32,7 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
|
||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
}
|
||||
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) WgAddress() WGAddress {
|
||||
func (t *WGTunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,11 +13,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -29,7 +30,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@@ -10,16 +11,16 @@ import (
|
||||
// PacketFilter interface for firewall abilities
|
||||
type PacketFilter interface {
|
||||
// DropOutgoing filter outgoing packets from host to external destinations
|
||||
DropOutgoing(packetData []byte) bool
|
||||
DropOutgoing(packetData []byte, size int) bool
|
||||
|
||||
// DropIncoming filter incoming packets from external sources to host
|
||||
DropIncoming(packetData []byte) bool
|
||||
DropIncoming(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
|
||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.DropIncoming(buf[offset:]) {
|
||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
dropped++
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
return 1, nil
|
||||
})
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
|
||||
@@ -14,11 +14,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
iceBind *bind.ICEBind
|
||||
@@ -30,7 +31,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
||||
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,12 +14,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
wgPort int
|
||||
key string
|
||||
mtu int
|
||||
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
|
||||
filterFn bind.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &TunKernelDevice{
|
||||
ctx: ctx,
|
||||
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
bindParams := bind.UniversalUDPMuxParams{
|
||||
UDPConn: rawSock,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
UDPConn: rawSock,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
}
|
||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
||||
go mux.ReadFromConn(t.ctx)
|
||||
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return t.udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
|
||||
return closErr
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) WgAddress() WGAddress {
|
||||
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
return &TunNetstackDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
|
||||
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) WgAddress() WGAddress {
|
||||
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type USPDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -28,7 +29,7 @@ type USPDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
return &USPDevice{
|
||||
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) WgAddress() WGAddress {
|
||||
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -13,13 +13,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
|
||||
|
||||
type TunDevice struct {
|
||||
name string
|
||||
address WGAddress
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu int
|
||||
@@ -32,7 +33,7 @@ type TunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||
return &TunDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
||||
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||
t.address = address
|
||||
return t.assignAddr()
|
||||
}
|
||||
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (t *TunDevice) WgAddress() WGAddress {
|
||||
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||
return t.address
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/freebsd"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
link, err := freebsd.LinkByName(l.name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link by name: %w", err)
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(l, 0)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,13 +7,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
@@ -28,8 +29,6 @@ const (
|
||||
WgInterfaceDefault = configurer.WgInterfaceDefault
|
||||
)
|
||||
|
||||
type WGAddress = device.WGAddress
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
Free() error
|
||||
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
|
||||
}
|
||||
|
||||
// Address returns the interface address
|
||||
func (w *WGIface) Address() device.WGAddress {
|
||||
func (w *WGIface) Address() wgaddr.Address {
|
||||
return w.tun.WgAddress()
|
||||
}
|
||||
|
||||
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
addr, err := device.ParseWGAddress(newAddr)
|
||||
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,17 +3,18 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
|
||||
@@ -6,17 +6,18 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -5,17 +5,18 @@ package iface
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
||||
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
return wgIFace, nil
|
||||
}
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
|
||||
@@ -4,16 +4,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -6,6 +6,7 @@ package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
||||
}
|
||||
|
||||
// DropIncoming mocks base method.
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DropIncoming indicates an expected call of DropIncoming.
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||
}
|
||||
|
||||
// DropOutgoing mocks base method.
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
|
||||
@@ -55,7 +55,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
|
||||
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err)
|
||||
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
||||
}
|
||||
if skipProxy {
|
||||
return nsTunDev, tunNet, nil
|
||||
|
||||
@@ -1,29 +1,29 @@
|
||||
package device
|
||||
package wgaddr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// WGAddress WireGuard parsed address
|
||||
type WGAddress struct {
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
IP net.IP
|
||||
Network *net.IPNet
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func ParseWGAddress(address string) (WGAddress, error) {
|
||||
func ParseWGAddress(address string) (Address, error) {
|
||||
ip, network, err := net.ParseCIDR(address)
|
||||
if err != nil {
|
||||
return WGAddress{}, err
|
||||
return Address{}, err
|
||||
}
|
||||
return WGAddress{
|
||||
return Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (addr WGAddress) String() string {
|
||||
func (addr Address) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -16,13 +17,13 @@ import (
|
||||
type ProxyBind struct {
|
||||
Bind *bind.ICEBind
|
||||
|
||||
wgAddr *net.UDPAddr
|
||||
wgEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
fakeNetIP *netip.AddrPort
|
||||
wgBindEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
@@ -33,20 +34,24 @@ type ProxyBind struct {
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
|
||||
fakeNetIP, err := fakeAddress(nbAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.wgAddr = addr
|
||||
p.wgEndpoint = addrToEndpoint(addr)
|
||||
p.fakeNetIP = fakeNetIP
|
||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return err
|
||||
return nil
|
||||
|
||||
}
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgAddr
|
||||
return &net.UDPAddr{
|
||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
||||
Port: int(p.fakeNetIP.Port()),
|
||||
Zone: p.fakeNetIP.Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Work() {
|
||||
@@ -54,6 +59,8 @@ func (p *ProxyBind) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
@@ -93,7 +100,7 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
||||
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
@@ -126,7 +133,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgEndpoint,
|
||||
Endpoint: p.wgBindEndpoint,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
@@ -134,8 +141,19 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse new IP: %w", err)
|
||||
}
|
||||
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
|
||||
!define INSTALLER_NAME "netbird-installer.exe"
|
||||
!define MAIN_APP_EXE "Netbird"
|
||||
!define ICON "ui\\netbird.ico"
|
||||
!define BANNER "ui\\banner.bmp"
|
||||
!define ICON "ui\\assets\\netbird.ico"
|
||||
!define BANNER "ui\\build\\banner.bmp"
|
||||
!define LICENSE_DATA "..\\LICENSE"
|
||||
|
||||
!define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}"
|
||||
@@ -22,6 +22,8 @@
|
||||
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
|
||||
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
|
||||
|
||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||
|
||||
Unicode True
|
||||
|
||||
######################################################################
|
||||
@@ -68,6 +70,9 @@ ShowInstDetails Show
|
||||
|
||||
!insertmacro MUI_PAGE_DIRECTORY
|
||||
|
||||
; Custom page for autostart checkbox
|
||||
Page custom AutostartPage AutostartPageLeave
|
||||
|
||||
!insertmacro MUI_PAGE_INSTFILES
|
||||
|
||||
!insertmacro MUI_PAGE_FINISH
|
||||
@@ -80,8 +85,36 @@ ShowInstDetails Show
|
||||
|
||||
!insertmacro MUI_LANGUAGE "English"
|
||||
|
||||
; Variables for autostart option
|
||||
Var AutostartCheckbox
|
||||
Var AutostartEnabled
|
||||
|
||||
######################################################################
|
||||
|
||||
; Function to create the autostart options page
|
||||
Function AutostartPage
|
||||
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
|
||||
|
||||
nsDialogs::Create 1018
|
||||
Pop $0
|
||||
|
||||
${If} $0 == error
|
||||
Abort
|
||||
${EndIf}
|
||||
|
||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||
Pop $AutostartCheckbox
|
||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||
|
||||
nsDialogs::Show
|
||||
FunctionEnd
|
||||
|
||||
; Function to handle leaving the autostart page
|
||||
Function AutostartPageLeave
|
||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||
FunctionEnd
|
||||
|
||||
Function GetAppFromCommand
|
||||
Exch $1
|
||||
Push $2
|
||||
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
EnVar::SetHKLM
|
||||
EnVar::AddValueEx "path" "$INSTDIR"
|
||||
|
||||
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
|
||||
# kill ui client
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
|
||||
# wait the service uninstall take unblock the executable
|
||||
Sleep 3000
|
||||
|
||||
@@ -240,7 +240,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
||||
|
||||
dPorts := convertPortInfo(rule.PortInfo)
|
||||
|
||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
||||
addedRule, err := d.firewall.AddRouteFiltering(rule.Id, sources, destination, protocol, nil, dPorts, action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add route rule: %w", err)
|
||||
}
|
||||
@@ -281,7 +281,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
rules, err = d.addInRules(r.Id, ip, protocol, port, action, ipsetName)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
rules, err = d.addOutRules(r.Id, ip, protocol, port, action, ipsetName)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
@@ -322,14 +322,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -338,18 +338,18 @@ func (d *DefaultManager) addInRules(
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
@@ -364,9 +364,8 @@ func (d *DefaultManager) getPeerRuleID(
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
comment string,
|
||||
) id.RuleID {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@@ -8,11 +9,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
@@ -45,20 +49,20 @@ func TestDefaultManager(t *testing.T) {
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset(nil)
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
@@ -339,20 +343,20 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset(nil)
|
||||
_ = fw.Close(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
iface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
||||
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
|
||||
}
|
||||
|
||||
// Address mocks base method.
|
||||
func (m *MockIFaceMapper) Address() iface.WGAddress {
|
||||
func (m *MockIFaceMapper) Address() wgaddr.Address {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Address")
|
||||
ret0, _ := ret[0].(iface.WGAddress)
|
||||
ret0, _ := ret[0].(wgaddr.Address)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan error) error {
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
runningChanOpen := true
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
if c.isContextCancelled() {
|
||||
if c.ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil && runningChanOpen {
|
||||
runningChan <- nil
|
||||
close(runningChan)
|
||||
runningChanOpen = false
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case runningChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConnectClient) isContextCancelled() bool {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SetNetworkMapPersistence enables or disables network map persistence.
|
||||
// When enabled, the last received network map will be stored and can be retrieved
|
||||
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -29,6 +31,8 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
}
|
||||
@@ -37,9 +41,9 @@ func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() iface.WGAddress {
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return iface.WGAddress{
|
||||
return wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}
|
||||
@@ -455,7 +459,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
@@ -916,7 +920,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface, false)
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
@@ -1015,7 +1019,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
mh.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Reset mocks
|
||||
// Close mocks
|
||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||
mh.ExpectedCalls = nil
|
||||
mh.Calls = nil
|
||||
|
||||
@@ -2,7 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
||||
@@ -5,15 +5,15 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
|
||||
@@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
|
||||
@@ -35,7 +35,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
@@ -191,7 +191,7 @@ type Engine struct {
|
||||
persistNetworkMap bool
|
||||
latestNetworkMap *mgmProto.NetworkMap
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager types.FlowManager
|
||||
flowManager nftypes.FlowManager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -234,7 +234,6 @@ func NewEngine(
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
flowManager: netflow.NewManager(clientCtx),
|
||||
}
|
||||
if runtime.GOOS == "ios" {
|
||||
if !fileExists(mobileDep.StateFilePath) {
|
||||
@@ -303,8 +302,6 @@ func (e *Engine) Stop() error {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
}
|
||||
|
||||
e.flowManager.Close()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
@@ -314,6 +311,12 @@ func (e *Engine) Stop() error {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
e.close()
|
||||
|
||||
// stop flow manager after wg interface is gone
|
||||
if e.flowManager != nil {
|
||||
e.flowManager.Close()
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
@@ -348,6 +351,10 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
if e.config.RosenpassPermissive {
|
||||
@@ -454,7 +461,7 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
@@ -488,13 +495,13 @@ func (e *Engine) initFirewall() error {
|
||||
|
||||
// this rule is static and will be torn down on engine down by the firewall manager
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
firewallManager.ProtocolUDP,
|
||||
nil,
|
||||
&port,
|
||||
firewallManager.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
); err != nil {
|
||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||
return nil
|
||||
@@ -518,6 +525,7 @@ func (e *Engine) blockLanAccess() {
|
||||
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
for _, network := range toBlock {
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{v4},
|
||||
network,
|
||||
firewallManager.ProtocolALL,
|
||||
@@ -721,12 +729,13 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
|
||||
return e.flowManager.Update(flowConfig)
|
||||
}
|
||||
|
||||
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*types.FlowConfig, error) {
|
||||
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
||||
if config.GetInterval() == nil {
|
||||
return nil, errors.New("flow interval is nil")
|
||||
}
|
||||
return &types.FlowConfig{
|
||||
return &nftypes.FlowConfig{
|
||||
Enabled: config.GetEnabled(),
|
||||
Counters: config.GetCounters(),
|
||||
URL: config.GetUrl(),
|
||||
TokenPayload: config.GetTokenPayload(),
|
||||
TokenSignature: config.GetTokenSignature(),
|
||||
@@ -1419,7 +1428,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Reset(e.stateManager)
|
||||
err := e.firewall.Close(e.stateManager)
|
||||
if err != nil {
|
||||
log.Warnf("failed to reset firewall: %s", err)
|
||||
}
|
||||
@@ -1632,16 +1641,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
log.Info("restarting engine")
|
||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
if e.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("restarting engine")
|
||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
log.Infof("cancelling client, engine will be recreated")
|
||||
log.Infof("cancelling client context, engine will be recreated")
|
||||
e.clientCancel()
|
||||
}
|
||||
|
||||
@@ -1653,34 +1665,17 @@ func (e *Engine) startNetworkMonitor() {
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
go func() {
|
||||
var mu sync.Mutex
|
||||
var debounceTimer *time.Timer
|
||||
|
||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
||||
// a network change is detected, or an error occurs on start up
|
||||
err := e.networkMonitor.Start(e.ctx, func() {
|
||||
// This function is called when a network change is detected
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
log.Infof("Network monitor: detected network change, reset debounceTimer")
|
||||
debounceTimer.Stop()
|
||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Infof("network monitor stopped")
|
||||
return
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(2*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
||||
log.Errorf("Network monitor: %v", err)
|
||||
log.Errorf("network monitor error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -75,7 +76,7 @@ type MockWGIface struct {
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() device.WGAddress
|
||||
AddressFunc func() wgaddr.Address
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
@@ -114,7 +115,7 @@ func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
func (m *MockWGIface) Address() wgaddr.Address {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
@@ -364,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
Address() wgaddr.Address
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
|
||||
306
client/internal/netflow/conntrack/conntrack.go
Normal file
306
client/internal/netflow/conntrack/conntrack.go
Normal file
@@ -0,0 +1,306 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
nfct "github.com/ti-mo/conntrack"
|
||||
"github.com/ti-mo/netfilter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const defaultChannelSize = 100
|
||||
|
||||
// ConnTrack manages kernel-based conntrack events
|
||||
type ConnTrack struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
iface nftypes.IFaceMapper
|
||||
|
||||
conn *nfct.Conn
|
||||
mux sync.Mutex
|
||||
|
||||
instanceID uuid.UUID
|
||||
started bool
|
||||
done chan struct{}
|
||||
sysctlModified bool
|
||||
}
|
||||
|
||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||
return &ConnTrack{
|
||||
flowLogger: flowLogger,
|
||||
iface: iface,
|
||||
instanceID: uuid.New(),
|
||||
started: false,
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||
func (c *ConnTrack) Start(enableCounters bool) error {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("Starting conntrack event listening")
|
||||
|
||||
if enableCounters {
|
||||
c.EnableAccounting()
|
||||
}
|
||||
|
||||
conn, err := nfct.Dial(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial conntrack: %w", err)
|
||||
}
|
||||
c.conn = conn
|
||||
|
||||
events := make(chan nfct.Event, defaultChannelSize)
|
||||
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
|
||||
netfilter.GroupCTNew,
|
||||
netfilter.GroupCTDestroy,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Errorf("Error closing conntrack connection: %v", err)
|
||||
}
|
||||
c.conn = nil
|
||||
return fmt.Errorf("start conntrack listener: %w", err)
|
||||
}
|
||||
|
||||
c.started = true
|
||||
|
||||
go c.receiverRoutine(events, errChan)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) {
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
c.handleEvent(event)
|
||||
case err := <-errChan:
|
||||
log.Errorf("Error from conntrack event listener: %v", err)
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Errorf("Error closing conntrack connection: %v", err)
|
||||
}
|
||||
return
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the connection tracking. This method is idempotent.
|
||||
func (c *ConnTrack) Stop() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if !c.started {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Stopping conntrack event listening")
|
||||
|
||||
select {
|
||||
case c.done <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Errorf("Error closing conntrack connection: %v", err)
|
||||
}
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
c.started = false
|
||||
|
||||
c.RestoreAccounting()
|
||||
}
|
||||
|
||||
// Close stops listening for events and cleans up resources
|
||||
func (c *ConnTrack) Close() error {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.started {
|
||||
select {
|
||||
case c.done <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
c.started = false
|
||||
|
||||
c.RestoreAccounting()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("close conntrack: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleEvent processes incoming conntrack events
|
||||
func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
if event.Flow == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if event.Type != nfct.EventNew && event.Type != nfct.EventDestroy {
|
||||
return
|
||||
}
|
||||
|
||||
flow := *event.Flow
|
||||
|
||||
proto := nftypes.Protocol(flow.TupleOrig.Proto.Protocol)
|
||||
if proto == nftypes.ProtocolUnknown {
|
||||
return
|
||||
}
|
||||
srcIP := flow.TupleOrig.IP.SourceAddress
|
||||
dstIP := flow.TupleOrig.IP.DestinationAddress
|
||||
|
||||
if !c.relevantFlow(srcIP, dstIP) {
|
||||
return
|
||||
}
|
||||
|
||||
var srcPort, dstPort uint16
|
||||
var icmpType, icmpCode uint8
|
||||
|
||||
switch proto {
|
||||
case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
|
||||
srcPort = flow.TupleOrig.Proto.SourcePort
|
||||
dstPort = flow.TupleOrig.Proto.DestinationPort
|
||||
case nftypes.ICMP:
|
||||
icmpType = flow.TupleOrig.Proto.ICMPType
|
||||
icmpCode = flow.TupleOrig.Proto.ICMPCode
|
||||
}
|
||||
|
||||
flowID := c.getFlowID(flow.ID)
|
||||
direction := c.inferDirection(srcIP, dstIP)
|
||||
|
||||
eventType := nftypes.TypeStart
|
||||
eventStr := "New"
|
||||
|
||||
if event.Type == nfct.EventDestroy {
|
||||
eventType = nftypes.TypeEnd
|
||||
eventStr = "Ended"
|
||||
}
|
||||
|
||||
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
c.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: eventType,
|
||||
Direction: direction,
|
||||
Protocol: proto,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
RxPackets: c.mapRxPackets(flow, direction),
|
||||
TxPackets: c.mapTxPackets(flow, direction),
|
||||
RxBytes: c.mapRxBytes(flow, direction),
|
||||
TxBytes: c.mapTxBytes(flow, direction),
|
||||
})
|
||||
}
|
||||
|
||||
// relevantFlow checks if the flow is related to the specified interface
|
||||
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
|
||||
// TODO: filter traffic by interface
|
||||
|
||||
wgnet := c.iface.Address().Network
|
||||
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// mapRxPackets maps packet counts to RX based on flow direction
|
||||
func (c *ConnTrack) mapRxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||
// For Ingress: CountersOrig is from external to us (RX)
|
||||
// For Egress: CountersReply is from external to us (RX)
|
||||
if direction == nftypes.Ingress {
|
||||
return flow.CountersOrig.Packets
|
||||
}
|
||||
return flow.CountersReply.Packets
|
||||
}
|
||||
|
||||
// mapTxPackets maps packet counts to TX based on flow direction
|
||||
func (c *ConnTrack) mapTxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||
// For Ingress: CountersReply is from us to external (TX)
|
||||
// For Egress: CountersOrig is from us to external (TX)
|
||||
if direction == nftypes.Ingress {
|
||||
return flow.CountersReply.Packets
|
||||
}
|
||||
return flow.CountersOrig.Packets
|
||||
}
|
||||
|
||||
// mapRxBytes maps byte counts to RX based on flow direction
|
||||
func (c *ConnTrack) mapRxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||
// For Ingress: CountersOrig is from external to us (RX)
|
||||
// For Egress: CountersReply is from external to us (RX)
|
||||
if direction == nftypes.Ingress {
|
||||
return flow.CountersOrig.Bytes
|
||||
}
|
||||
return flow.CountersReply.Bytes
|
||||
}
|
||||
|
||||
// mapTxBytes maps byte counts to TX based on flow direction
|
||||
func (c *ConnTrack) mapTxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||
// For Ingress: CountersReply is from us to external (TX)
|
||||
// For Egress: CountersOrig is from us to external (TX)
|
||||
if direction == nftypes.Ingress {
|
||||
return flow.CountersReply.Bytes
|
||||
}
|
||||
return flow.CountersOrig.Bytes
|
||||
}
|
||||
|
||||
// getFlowID creates a unique UUID based on the conntrack ID and instance ID
|
||||
func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
|
||||
var buf [4]byte
|
||||
binary.BigEndian.PutUint32(buf[:], conntrackID)
|
||||
return uuid.NewSHA1(c.instanceID, buf[:])
|
||||
}
|
||||
|
||||
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||
wgaddr := c.iface.Address().IP
|
||||
wgnetwork := c.iface.Address().Network
|
||||
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||
|
||||
switch {
|
||||
case wgaddr.Equal(src):
|
||||
return nftypes.Egress
|
||||
case wgaddr.Equal(dst):
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(src):
|
||||
// netbird network -> resource network
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(dst):
|
||||
// resource network -> netbird network
|
||||
return nftypes.Egress
|
||||
|
||||
// TODO: handle site2site traffic
|
||||
}
|
||||
|
||||
return nftypes.DirectionUnknown
|
||||
}
|
||||
9
client/internal/netflow/conntrack/conntrack_nonlinux.go
Normal file
9
client/internal/netflow/conntrack/conntrack_nonlinux.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package conntrack
|
||||
|
||||
import nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
|
||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) nftypes.ConnTracker {
|
||||
return nil
|
||||
}
|
||||
73
client/internal/netflow/conntrack/sysctl.go
Normal file
73
client/internal/netflow/conntrack/sysctl.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// conntrackAcctPath is the sysctl path for conntrack accounting
|
||||
conntrackAcctPath = "net.netfilter.nf_conntrack_acct"
|
||||
)
|
||||
|
||||
// EnableAccounting ensures that connection tracking accounting is enabled in the kernel.
|
||||
func (c *ConnTrack) EnableAccounting() {
|
||||
// haven't restored yet
|
||||
if c.sysctlModified {
|
||||
return
|
||||
}
|
||||
|
||||
modified, err := setSysctl(conntrackAcctPath, 1)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to enable conntrack accounting: %v", err)
|
||||
return
|
||||
}
|
||||
c.sysctlModified = modified
|
||||
}
|
||||
|
||||
// RestoreAccounting restores the connection tracking accounting setting to its original value.
|
||||
func (c *ConnTrack) RestoreAccounting() {
|
||||
if !c.sysctlModified {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := setSysctl(conntrackAcctPath, 0); err != nil {
|
||||
log.Warnf("Failed to restore conntrack accounting: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.sysctlModified = false
|
||||
}
|
||||
|
||||
// setSysctl sets a sysctl configuration and returns whether it was modified.
|
||||
func setSysctl(key string, desiredValue int) (bool, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return false, fmt.Errorf("convert current value to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return false, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
return true, nil
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/store"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type rcvChan chan *types.EventFields
|
||||
@@ -21,15 +22,17 @@ type Logger struct {
|
||||
enabled atomic.Bool
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancelReceiver context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(ctx context.Context) *Logger {
|
||||
func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Logger{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
Store: store.NewMemoryStore(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
Store: store.NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,13 +61,14 @@ func (l *Logger) startReceiver() {
|
||||
if l.enabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
l.mux.Lock()
|
||||
ctx, cancel := context.WithCancel(l.ctx)
|
||||
l.cancelReceiver = cancel
|
||||
l.mux.Unlock()
|
||||
|
||||
c := make(rcvChan, 100)
|
||||
l.rcvChan.Swap(&c)
|
||||
l.rcvChan.Store(&c)
|
||||
l.enabled.Store(true)
|
||||
|
||||
for {
|
||||
@@ -73,12 +77,15 @@ func (l *Logger) startReceiver() {
|
||||
log.Info("flow Memory store receiver stopped")
|
||||
return
|
||||
case eventFields := <-c:
|
||||
id := uuid.NewString()
|
||||
id := uuid.New()
|
||||
event := types.Event{
|
||||
ID: id,
|
||||
EventFields: *eventFields,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
|
||||
event.SourceResourceID = []byte(srcResId)
|
||||
event.DestResourceID = []byte(dstResId)
|
||||
l.Store.StoreEvent(&event)
|
||||
}
|
||||
}
|
||||
@@ -100,6 +107,7 @@ func (l *Logger) stop() {
|
||||
l.cancelReceiver()
|
||||
l.cancelReceiver = nil
|
||||
}
|
||||
l.rcvChan.Store(nil)
|
||||
l.mux.Unlock()
|
||||
}
|
||||
|
||||
@@ -107,6 +115,10 @@ func (l *Logger) GetEvents() []*types.Event {
|
||||
return l.Store.GetEvents()
|
||||
}
|
||||
|
||||
func (l *Logger) DeleteEvents(ids []uuid.UUID) {
|
||||
l.Store.DeleteEvents(ids)
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.cancel()
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(context.Background())
|
||||
logger := logger.New(context.Background(), nil)
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
|
||||
@@ -2,47 +2,234 @@ package netflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/logger"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/flow/client"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
)
|
||||
|
||||
// Manager handles netflow tracking and logging
|
||||
type Manager struct {
|
||||
mux sync.Mutex
|
||||
logger types.FlowLogger
|
||||
flowConfig *types.FlowConfig
|
||||
mux sync.Mutex
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
ctx context.Context
|
||||
receiverClient *client.GRPCClient
|
||||
publicKey []byte
|
||||
}
|
||||
|
||||
func NewManager(ctx context.Context) *Manager {
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
flowLogger := logger.New(ctx, statusRecorder)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
ct = conntrack.New(flowLogger, iface)
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
logger: logger.New(ctx),
|
||||
logger: flowLogger,
|
||||
conntrack: ct,
|
||||
ctx: ctx,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Update(update *types.FlowConfig) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
if update == nil {
|
||||
return nil
|
||||
// Update applies new flow configuration settings
|
||||
// needsNewClient checks if a new client needs to be created
|
||||
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||
current := m.flowConfig
|
||||
return previous == nil ||
|
||||
!previous.Enabled ||
|
||||
previous.TokenPayload != current.TokenPayload ||
|
||||
previous.TokenSignature != current.TokenSignature ||
|
||||
previous.URL != current.URL
|
||||
}
|
||||
|
||||
// enableFlow starts components for flow tracking
|
||||
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
// first make sender ready so events don't pile up
|
||||
if m.needsNewClient(previous) {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
go m.receiveACKs(flowClient)
|
||||
go m.startSender()
|
||||
}
|
||||
|
||||
m.flowConfig = update
|
||||
m.logger.Enable()
|
||||
|
||||
if update.Enabled {
|
||||
m.logger.Enable()
|
||||
return nil
|
||||
if m.conntrack != nil {
|
||||
if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
|
||||
return fmt.Errorf("start conntrack: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Disable()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableFlow stops components for flow tracking
|
||||
func (m *Manager) disableFlow() error {
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Stop()
|
||||
}
|
||||
|
||||
m.logger.Disable()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update applies new flow configuration settings
|
||||
func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
if update == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
previous := m.flowConfig
|
||||
m.flowConfig = update
|
||||
|
||||
if update.Enabled {
|
||||
return m.enableFlow(previous)
|
||||
}
|
||||
|
||||
return m.disableFlow()
|
||||
}
|
||||
|
||||
// Close cleans up all resources
|
||||
func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Close()
|
||||
}
|
||||
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("failed to close receiver client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Close()
|
||||
}
|
||||
|
||||
func (m *Manager) GetLogger() types.FlowLogger {
|
||||
// GetLogger returns the flow logger
|
||||
func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||
return m.logger
|
||||
}
|
||||
|
||||
func (m *Manager) startSender() {
|
||||
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
events := m.logger.GetEvents()
|
||||
for _, event := range events {
|
||||
if err := m.send(event); err != nil {
|
||||
log.Errorf("failed to send flow event to server: %s", err)
|
||||
continue
|
||||
}
|
||||
log.Tracef("sent flow event: %s", event.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
log.Tracef("received flow event ack: %s", ack.EventId)
|
||||
m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("failed to receive flow event ack: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) send(event *nftypes.Event) error {
|
||||
m.mux.Lock()
|
||||
client := m.receiverClient
|
||||
m.mux.Unlock()
|
||||
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return client.Send(toProtoEvent(m.publicKey, event))
|
||||
}
|
||||
|
||||
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
|
||||
protoEvent := &proto.FlowEvent{
|
||||
EventId: event.ID[:],
|
||||
Timestamp: timestamppb.New(event.Timestamp),
|
||||
PublicKey: publicKey,
|
||||
FlowFields: &proto.FlowFields{
|
||||
FlowId: event.FlowID[:],
|
||||
RuleId: event.RuleID,
|
||||
Type: proto.Type(event.Type),
|
||||
Direction: proto.Direction(event.Direction),
|
||||
Protocol: uint32(event.Protocol),
|
||||
SourceIp: event.SourceIP.AsSlice(),
|
||||
DestIp: event.DestIP.AsSlice(),
|
||||
RxPackets: event.RxPackets,
|
||||
TxPackets: event.TxPackets,
|
||||
RxBytes: event.RxBytes,
|
||||
TxBytes: event.TxBytes,
|
||||
SourceResourceId: event.SourceResourceID,
|
||||
DestResourceId: event.DestResourceID,
|
||||
},
|
||||
}
|
||||
|
||||
if event.Protocol == nftypes.ICMP {
|
||||
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
|
||||
IcmpInfo: &proto.ICMPInfo{
|
||||
IcmpType: uint32(event.ICMPType),
|
||||
IcmpCode: uint32(event.ICMPCode),
|
||||
},
|
||||
}
|
||||
return protoEvent
|
||||
}
|
||||
|
||||
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_PortInfo{
|
||||
PortInfo: &proto.PortInfo{
|
||||
SourcePort: uint32(event.SourcePort),
|
||||
DestPort: uint32(event.DestPort),
|
||||
},
|
||||
}
|
||||
|
||||
return protoEvent
|
||||
}
|
||||
|
||||
@@ -3,18 +3,22 @@ package store
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
func NewMemoryStore() *Memory {
|
||||
return &Memory{
|
||||
events: make(map[string]*types.Event),
|
||||
events: make(map[uuid.UUID]*types.Event),
|
||||
}
|
||||
}
|
||||
|
||||
type Memory struct {
|
||||
mux sync.Mutex
|
||||
events map[string]*types.Event
|
||||
events map[uuid.UUID]*types.Event
|
||||
}
|
||||
|
||||
func (m *Memory) StoreEvent(event *types.Event) {
|
||||
@@ -26,7 +30,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
|
||||
func (m *Memory) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.events = make(map[string]*types.Event)
|
||||
maps.Clear(m.events)
|
||||
}
|
||||
|
||||
func (m *Memory) GetEvents() []*types.Event {
|
||||
@@ -39,7 +43,7 @@ func (m *Memory) GetEvents() []*types.Event {
|
||||
return events
|
||||
}
|
||||
|
||||
func (m *Memory) DeleteEvents(ids []string) {
|
||||
func (m *Memory) DeleteEvents(ids []uuid.UUID) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, id := range ids {
|
||||
|
||||
@@ -2,48 +2,98 @@ package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type Protocol uint8
|
||||
|
||||
const (
|
||||
ProtocolUnknown = Protocol(0)
|
||||
ICMP = Protocol(1)
|
||||
TCP = Protocol(6)
|
||||
UDP = Protocol(17)
|
||||
SCTP = Protocol(132)
|
||||
)
|
||||
|
||||
func (p Protocol) String() string {
|
||||
switch p {
|
||||
case 1:
|
||||
return "ICMP"
|
||||
case 6:
|
||||
return "TCP"
|
||||
case 17:
|
||||
return "UDP"
|
||||
case 132:
|
||||
return "SCTP"
|
||||
default:
|
||||
return strconv.FormatUint(uint64(p), 10)
|
||||
}
|
||||
}
|
||||
|
||||
type Type int
|
||||
|
||||
const (
|
||||
TypeStart = iota
|
||||
TypeUnknown = Type(iota)
|
||||
TypeStart
|
||||
TypeEnd
|
||||
TypeDrop
|
||||
)
|
||||
|
||||
type Direction int
|
||||
|
||||
func (d Direction) String() string {
|
||||
switch d {
|
||||
case Ingress:
|
||||
return "ingress"
|
||||
case Egress:
|
||||
return "egress"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
Ingress = iota
|
||||
DirectionUnknown = Direction(iota)
|
||||
Ingress
|
||||
Egress
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
ID string
|
||||
ID uuid.UUID
|
||||
Timestamp time.Time
|
||||
EventFields
|
||||
}
|
||||
|
||||
type EventFields struct {
|
||||
FlowID uuid.UUID
|
||||
Type Type
|
||||
Direction Direction
|
||||
Protocol uint8
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
FlowID uuid.UUID
|
||||
Type Type
|
||||
RuleID []byte
|
||||
Direction Direction
|
||||
Protocol Protocol
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
SourceResourceID []byte
|
||||
DestResourceID []byte
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
RxPackets uint64
|
||||
TxPackets uint64
|
||||
RxBytes uint64
|
||||
TxBytes uint64
|
||||
}
|
||||
|
||||
type FlowConfig struct {
|
||||
URL string
|
||||
Interval time.Duration
|
||||
Enabled bool
|
||||
Counters bool
|
||||
TokenPayload string
|
||||
TokenSignature string
|
||||
}
|
||||
@@ -62,6 +112,8 @@ type FlowLogger interface {
|
||||
StoreEvent(flowEvent EventFields)
|
||||
// GetEvents returns all stored events
|
||||
GetEvents() []*Event
|
||||
// DeleteEvents deletes events from the store
|
||||
DeleteEvents([]uuid.UUID)
|
||||
// Close closes the logger
|
||||
Close()
|
||||
// Enable enables the flow logger receiver
|
||||
@@ -76,7 +128,24 @@ type Store interface {
|
||||
// GetEvents returns all stored events
|
||||
GetEvents() []*Event
|
||||
// DeleteEvents deletes events from the store
|
||||
DeleteEvents([]string)
|
||||
DeleteEvents([]uuid.UUID)
|
||||
// Close closes the store
|
||||
Close()
|
||||
}
|
||||
|
||||
// ConnTracker defines the interface for connection tracking functionality
|
||||
type ConnTracker interface {
|
||||
// Start begins tracking connections by listening for conntrack events.
|
||||
Start(bool) error
|
||||
// Stop stops the connection tracking.
|
||||
Stop()
|
||||
// Close stops listening for events and cleans up resources
|
||||
Close() error
|
||||
}
|
||||
|
||||
// IFaceMapper provides interface to check if we're using userspace WireGuard
|
||||
type IFaceMapper interface {
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
@@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Debugf("Network monitor: closed routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
@@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
||||
return errors.New("no interfaces available")
|
||||
}
|
||||
@@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
|
||||
return ctx.Err()
|
||||
// handle route changes
|
||||
case route := <-routeChan:
|
||||
// default route and main table
|
||||
@@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
// triggered on added/replaced routes
|
||||
case syscall.RTM_NEWROUTE:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_DELROUTE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
routeMonitor, err := systemops.NewRouteMonitor(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route monitor: %w", err)
|
||||
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
return ctx.Err()
|
||||
case route := <-routeMonitor.RouteUpdates():
|
||||
if route.Destination.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if routeChanged(route, nexthopv4, nexthopv6, callback) {
|
||||
break
|
||||
if routeChanged(route, nexthopv4, nexthopv6) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
|
||||
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
@@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
|
||||
case systemops.RouteModified:
|
||||
// TODO: get routing table to figure out if our route is affected for modified routes
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
case systemops.RouteAdded:
|
||||
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
|
||||
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
}
|
||||
case systemops.RouteDeleted:
|
||||
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
|
||||
go callback()
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,27 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
var ErrStopped = errors.New("monitor has been stopped")
|
||||
const (
|
||||
debounceTime = 2 * time.Second
|
||||
)
|
||||
|
||||
var checkChangeFn = checkChange
|
||||
|
||||
// NetworkMonitor watches for changes in network configuration.
|
||||
type NetworkMonitor struct {
|
||||
@@ -19,3 +34,99 @@ type NetworkMonitor struct {
|
||||
func New() *NetworkMonitor {
|
||||
return &NetworkMonitor{}
|
||||
}
|
||||
|
||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||
nw.mu.Lock()
|
||||
if nw.cancel != nil {
|
||||
nw.mu.Unlock()
|
||||
return errors.New("network monitor already started")
|
||||
}
|
||||
|
||||
ctx, nw.cancel = context.WithCancel(ctx)
|
||||
defer nw.cancel()
|
||||
nw.wg.Add(1)
|
||||
nw.mu.Unlock()
|
||||
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
return nil
|
||||
}
|
||||
|
||||
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
|
||||
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||
return fmt.Errorf("failed to get default next hops: %w", err)
|
||||
}
|
||||
|
||||
// recover in case sys ops panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
event := make(chan struct{}, 1)
|
||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||
|
||||
// debounce changes
|
||||
timer := time.NewTimer(0)
|
||||
timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-event:
|
||||
timer.Reset(debounceTime)
|
||||
case <-timer.C:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the network monitor.
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
nw.mu.Lock()
|
||||
defer nw.mu.Unlock()
|
||||
|
||||
if nw.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
nw.cancel()
|
||||
nw.wg.Wait()
|
||||
}
|
||||
|
||||
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
|
||||
for {
|
||||
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
|
||||
close(event)
|
||||
return
|
||||
}
|
||||
// prevent blocking
|
||||
select {
|
||||
case event <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
||||
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
nw.mu.Lock()
|
||||
ctx, nw.cancel = context.WithCancel(ctx)
|
||||
nw.mu.Unlock()
|
||||
|
||||
nw.wg.Add(1)
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
return nil
|
||||
}
|
||||
|
||||
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
|
||||
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||
return fmt.Errorf("failed to get default next hops: %w", err)
|
||||
}
|
||||
|
||||
// recover in case sys ops panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
|
||||
return fmt.Errorf("check change: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the network monitor.
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
nw.mu.Lock()
|
||||
defer nw.mu.Unlock()
|
||||
|
||||
if nw.cancel != nil {
|
||||
nw.cancel()
|
||||
nw.wg.Wait()
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user